You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hu...@apache.org on 2022/03/02 19:55:01 UTC
[spark] branch branch-3.1 updated: [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K
This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 357d3b2 [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K
357d3b2 is described below
commit 357d3b24173405cdf915be60f2cebe442fa31536
Author: Ruifeng Zheng <ru...@foxmail.com>
AuthorDate: Wed Mar 2 11:51:06 2022 -0800
[SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K
### What changes were proposed in this pull request?
SPARK-31007 introduced auxiliary statistics to speed up computation in KMeans.
However, it needs an array of size `k * (k + 1) / 2`, which may cause overflow or OOM when k is too large.
So we should skip this optimization in this case.
### Why are the changes needed?
avoid overflow or OOM when k is too large (like 50,000)
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
existing testsuites
Closes #35457 from zhengruifeng/kmean_k_limit.
Authored-by: Ruifeng Zheng <ru...@foxmail.com>
Signed-off-by: huaxingao <hu...@apple.com>
(cherry picked from commit ad5427ebe644fc01a9b4c19a48f902f584245edf)
Signed-off-by: huaxingao <hu...@apple.com>
---
.../spark/mllib/clustering/DistanceMeasure.scala | 23 ++++++++++++++++++++++
.../org/apache/spark/mllib/clustering/KMeans.scala | 15 ++++++++++----
.../spark/mllib/clustering/KMeansModel.scala | 11 +++++++++--
3 files changed, 43 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
index 9ac473a..e4c29a7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -118,6 +118,24 @@ private[spark] abstract class DistanceMeasure extends Serializable {
}
/**
+ * @param centers the clustering centers
+ * @param statistics optional statistics to accelerate the computation, which should not
+ * change the result.
+ * @param point given point
+ * @return the index of the closest center to the given point, as well as the cost.
+ */
+ def findClosest(
+ centers: Array[VectorWithNorm],
+ statistics: Option[Array[Double]],
+ point: VectorWithNorm): (Int, Double) = {
+ if (statistics.nonEmpty) {
+ findClosest(centers, statistics.get, point)
+ } else {
+ findClosest(centers, point)
+ }
+ }
+
+ /**
* @return the index of the closest center to the given point, as well as the cost.
*/
def findClosest(
@@ -253,6 +271,11 @@ object DistanceMeasure {
case _ => false
}
}
+
+ private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000
+
+ private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean =
+ k.toLong * k * numFeatures < 1000000
}
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 76e2928..c140b1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -269,15 +269,22 @@ class KMeans private (
instr.foreach(_.logNumFeatures(numFeatures))
- val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
+ val shouldComputeStats =
+ DistanceMeasure.shouldComputeStatistics(centers.length)
+ val shouldComputeStatsLocally =
+ DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures)
// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && !converged) {
val bcCenters = sc.broadcast(centers)
- val stats = if (shouldDistributed) {
- distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
+ val stats = if (shouldComputeStats) {
+ if (shouldComputeStatsLocally) {
+ Some(distanceMeasureInstance.computeStatistics(centers))
+ } else {
+ Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters))
+ }
} else {
- distanceMeasureInstance.computeStatistics(centers)
+ None
}
val bcStats = sc.broadcast(stats)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index a24493b..64b3521 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -50,9 +50,16 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
// TODO: computation of statistics may take seconds, so save it to KMeansModel in training
@transient private lazy val statistics = if (clusterCenters == null) {
- null
+ None
} else {
- distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)
+ val k = clusterCenters.length
+ val numFeatures = clusterCenters.head.size
+ if (DistanceMeasure.shouldComputeStatistics(k) &&
+ DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) {
+ Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm))
+ } else {
+ None
+ }
}
@Since("2.4.0")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org