Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/05/03 12:33:09 UTC

[GitHub] [spark] zhengruifeng commented on a diff in pull request #36437: [SPARK-30661][ML][PYTHON] KMeans blockify input vectors

zhengruifeng commented on code in PR #36437:
URL: https://github.com/apache/spark/pull/36437#discussion_r863725988


##########
mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala:
##########
@@ -398,3 +652,133 @@ class KMeansSummary private[clustering] (
     numIter: Int,
     @Since("2.4.0") val trainingCost: Double)
   extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter)
+
+/**
+ * KMeansAggregator computes the distances and updates the centers for blocks
+ * stored as sparse or dense matrices, in an online fashion.
+ * @param centerMatrix The matrix containing center vectors.
+ * @param k The number of clusters.
+ * @param numFeatures The number of features.
+ * @param distanceMeasure The distance measure.
+ *                        When 'euclidean' is chosen, the instance blocks should contain
+ *                        the squared norms in the labels field;
+ *                        When 'cosine' is chosen, the vectors should be already normalized.
+ */
+private class KMeansAggregator (
+    val centerMatrix: DenseMatrix,
+    val k: Int,
+    val numFeatures: Int,
+    val distanceMeasure: String) extends Serializable {
+  import KMeans.{EUCLIDEAN, COSINE}
+
+  def weightSum: Double = weightSumVec.values.sum
+
+  var costSum = 0.0
+  var count = 0L
+  val weightSumVec = new DenseVector(Array.ofDim[Double](k))
+  val sumMat = new DenseMatrix(k, numFeatures, Array.ofDim[Double](k * numFeatures))
+
+  @transient private lazy val centerSquaredNorms = {
+    distanceMeasure match {
+      case EUCLIDEAN =>
+        centerMatrix.rowIter.map(center => center.dot(center)).toArray
+      case COSINE => null
+    }
+  }
+
+  // avoid reallocating a dense matrix (size x k) for each instance block
+  @transient private var buffer: Array[Double] = _

Review Comment:
   I found it crucial to avoid reallocating the buffer, especially when k is relatively large.
   That is why BLAS was also modified so that GEMM can write its output to arrays as well as to matrices.
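To illustrate the point about the buffer, here is a minimal plain-Scala sketch, assuming row-major flat arrays. `BlockDotProducts` and `compute` are hypothetical names, and the triple loop stands in for a GEMM call (X times C transposed); it is not Spark's BLAS API. The buffer is only reallocated when a block needs more than `numRows * k` slots than it currently has.

```scala
// Hypothetical helper (not Spark's BLAS): dot products of a block of instances
// against k centers, written into a buffer that is reused across blocks.
final class BlockDotProducts(centers: Array[Double], k: Int, numFeatures: Int) {
  require(centers.length == k * numFeatures, "centers must be k x numFeatures")

  // Reused across calls; reallocated only when a larger block arrives.
  private var buffer: Array[Double] = null

  // `block` holds numRows instances, row-major (numRows x numFeatures).
  // Returns the buffer with out(i * k + j) = dot(x_i, c_j), i.e. X * C^T.
  // Only the first numRows * k entries are meaningful; the rest may be stale.
  def compute(block: Array[Double], numRows: Int): Array[Double] = {
    require(block.length == numRows * numFeatures, "block must be numRows x numFeatures")
    val needed = numRows * k
    if (buffer == null || buffer.length < needed) buffer = new Array[Double](needed)
    val out = buffer
    var i = 0
    while (i < numRows) {
      var j = 0
      while (j < k) {
        var s = 0.0
        var f = 0
        while (f < numFeatures) {
          s += block(i * numFeatures + f) * centers(j * numFeatures + f)
          f += 1
        }
        out(i * k + j) = s
        j += 1
      }
      i += 1
    }
    out
  }
}
```

With blocks of similar size, this allocates once instead of once per block; since each allocation is `numRows * k` doubles, the saving grows with k.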


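The doc comment's requirement that euclidean blocks carry squared norms lets the distance be derived from a dot product: ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, with the center norms precomputed once (as `centerSquaredNorms` does). Below is a minimal sketch of one online update step under that identity. `OnlineKMeansStep` and its fields are hypothetical analogues of `costSum`, `count`, `weightSumVec` and `sumMat`. Dot products are computed one instance at a time for clarity rather than blockwise. For "cosine" it assumes already-normalized inputs and uses 1 - x.c.

```scala
// Hypothetical sketch of one "online" update step, loosely modelled on the
// fields declared in KMeansAggregator (costSum, count, weightSumVec, sumMat).
final class OnlineKMeansStep(centers: Array[Array[Double]], distanceMeasure: String) {
  private val k = centers.length
  private val numFeatures = centers.head.length
  // Precomputed once, like centerSquaredNorms (only used for "euclidean").
  private val centerSqNorms = centers.map(c => c.map(v => v * v).sum)

  var costSum = 0.0
  var count = 0L
  val weightSums = new Array[Double](k)
  val sums = Array.ofDim[Double](k, numFeatures)

  private def dot(a: Array[Double], b: Array[Double]): Double = {
    var s = 0.0; var i = 0
    while (i < a.length) { s += a(i) * b(i); i += 1 }
    s
  }

  // sqNorm is ||x||^2, supplied by the caller (cf. the squared norms carried in
  // the labels field); for "cosine", x and the centers are assumed normalized.
  // Returns the index of the closest center.
  def add(x: Array[Double], sqNorm: Double, weight: Double): Int = {
    var best = 0
    var bestDist = Double.PositiveInfinity
    var j = 0
    while (j < k) {
      val d = dot(x, centers(j))
      val dist = distanceMeasure match {
        // ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, clamped against rounding error
        case "euclidean" => math.max(sqNorm + centerSqNorms(j) - 2.0 * d, 0.0)
        case "cosine"    => 1.0 - d
        case other       => throw new IllegalArgumentException(s"unknown measure: $other")
      }
      if (dist < bestDist) { bestDist = dist; best = j }
      j += 1
    }
    // Accumulate weighted cost, per-cluster weight and per-cluster weighted sum.
    costSum += weight * bestDist
    count += 1
    weightSums(best) += weight
    var f = 0
    while (f < numFeatures) { sums(best)(f) += weight * x(f); f += 1 }
    best
  }
}
```

After all instances are added, new centers would be `sums(j)` divided by `weightSums(j)` for each non-empty cluster.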

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

