You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2020/05/25 14:19:42 UTC

[spark] branch master updated: [SPARK-31734][ML][PYSPARK] Add weight support in ClusteringEvaluator

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d400777  [SPARK-31734][ML][PYSPARK] Add weight support in ClusteringEvaluator
d400777 is described below

commit d4007776f2dd85f03f3811ab8ca711f221f62c00
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Mon May 25 09:18:08 2020 -0500

    [SPARK-31734][ML][PYSPARK] Add weight support in ClusteringEvaluator
    
    ### What changes were proposed in this pull request?
    Add weight support in ClusteringEvaluator
    
    ### Why are the changes needed?
    Currently, BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator support instance weights, but ClusteringEvaluator does not. This PR adds instance-weight support to ClusteringEvaluator for consistency.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
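    Adds `ClusteringEvaluator.setWeightCol` (Scala and PySpark). When no weight column is set, every row is given weight 1.0, so existing results are unchanged.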
    
    ### How was this patch tested?
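    Added new unit tests ("test weight support" and "single-element clusters with weight" in ClusteringEvaluatorSuite) and a PySpark doctest.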
    
    Closes #28553 from huaxingao/weight_evaluator.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../spark/ml/evaluation/ClusteringEvaluator.scala  |  34 ++++--
 .../spark/ml/evaluation/ClusteringMetrics.scala    | 128 ++++++++++++---------
 .../ml/evaluation/ClusteringEvaluatorSuite.scala   |  43 ++++++-
 python/pyspark/ml/evaluation.py                    |  29 ++++-
 4 files changed, 167 insertions(+), 67 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 63b99a0..19790fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -19,10 +19,11 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * Evaluator for clustering results.
@@ -34,7 +35,8 @@ import org.apache.spark.sql.functions.col
  */
 @Since("2.3.0")
 class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasFeaturesCol with DefaultParamsWritable {
+  extends Evaluator with HasPredictionCol with HasFeaturesCol with HasWeightCol
+    with DefaultParamsWritable {
 
   @Since("2.3.0")
   def this() = this(Identifiable.randomUID("cluEval"))
@@ -53,6 +55,10 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
   @Since("2.3.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
 
+  /** @group setParam */
+  @Since("3.1.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   /**
    * param for metric name in evaluation
    * (supports `"silhouette"` (default))
@@ -116,12 +122,26 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
    */
   @Since("3.1.0")
   def getMetrics(dataset: Dataset[_]): ClusteringMetrics = {
-    SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol))
-    SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))
+    val schema = dataset.schema
+    SchemaUtils.validateVectorCompatibleColumn(schema, $(featuresCol))
+    SchemaUtils.checkNumericType(schema, $(predictionCol))
+    if (isDefined(weightCol)) {
+      SchemaUtils.checkNumericType(schema, $(weightCol))
+    }
+
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
     val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
-    val df = dataset.select(col($(predictionCol)),
-      vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata))
+    val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+      dataset.select(col($(predictionCol)),
+        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
+        lit(1.0).as(weightColName))
+    } else {
+      dataset.select(col($(predictionCol)),
+        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
+        col(weightColName).cast(DoubleType))
+    }
+
     val metrics = new ClusteringMetrics(df)
     metrics.setDistanceMeasure($(distanceMeasure))
     metrics
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index 3097033..8bf4ee1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -47,9 +47,9 @@ class ClusteringMetrics private[spark](dataset: Dataset[_]) {
     val columns = dataset.columns.toSeq
     if (distanceMeasure.equalsIgnoreCase("squaredEuclidean")) {
       SquaredEuclideanSilhouette.computeSilhouetteScore(
-        dataset, columns(0), columns(1))
+        dataset, columns(0), columns(1), columns(2))
     } else {
-      CosineSilhouette.computeSilhouetteScore(dataset, columns(0), columns(1))
+      CosineSilhouette.computeSilhouetteScore(dataset, columns(0), columns(1), columns(2))
     }
   }
 }
@@ -63,9 +63,10 @@ private[evaluation] abstract class Silhouette {
   def pointSilhouetteCoefficient(
       clusterIds: Set[Double],
       pointClusterId: Double,
-      pointClusterNumOfPoints: Long,
+      weightSum: Double,
+      weight: Double,
       averageDistanceToCluster: (Double) => Double): Double = {
-    if (pointClusterNumOfPoints == 1) {
+    if (weightSum == weight) {
       // Single-element clusters have silhouette 0
       0.0
     } else {
@@ -77,8 +78,8 @@ private[evaluation] abstract class Silhouette {
       val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min
       // adjustment for excluding the node itself from the computation of the average dissimilarity
       val currentClusterDissimilarity =
-      averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints /
-        (pointClusterNumOfPoints - 1)
+      averageDistanceToCluster(pointClusterId) * weightSum /
+        (weightSum - weight)
       if (currentClusterDissimilarity < neighboringClusterDissimilarity) {
         1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
       } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) {
@@ -92,8 +93,8 @@ private[evaluation] abstract class Silhouette {
   /**
    * Compute the mean Silhouette values of all samples.
    */
-  def overallScore(df: DataFrame, scoreColumn: Column): Double = {
-    df.select(avg(scoreColumn)).collect()(0).getDouble(0)
+  def overallScore(df: DataFrame, scoreColumn: Column, weightColumn: Column): Double = {
+    df.select(sum(scoreColumn * weightColumn) / sum(weightColumn)).collect()(0).getDouble(0)
   }
 }
 
@@ -267,7 +268,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
     }
   }
 
-  case class ClusterStats(featureSum: Vector, squaredNormSum: Double, numOfPoints: Long)
+  case class ClusterStats(featureSum: Vector, squaredNormSum: Double, weightSum: Double)
 
   /**
    * The method takes the input dataset and computes the aggregated values
@@ -277,6 +278,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
    * @param predictionCol The name of the column which contains the predicted cluster id
    *                      for the point.
    * @param featuresCol The name of the column which contains the feature vector of the point.
+   * @param weightCol The name of the column which contains the instance weight.
    * @return A [[scala.collection.immutable.Map]] which associates each cluster id
    *         to a [[ClusterStats]] object (which contains the precomputed values `N`,
    *         `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
@@ -284,36 +286,39 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
   def computeClusterStats(
       df: DataFrame,
       predictionCol: String,
-      featuresCol: String): Map[Double, ClusterStats] = {
+      featuresCol: String,
+      weightCol: String): Map[Double, ClusterStats] = {
     val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
-      col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
+      col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol))
       .rdd
-      .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) }
-      .aggregateByKey[(DenseVector, Double, Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
-      seqOp = {
-        case (
-          (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long),
-          (features, squaredNorm)
-          ) =>
-          BLAS.axpy(1.0, features, featureSum)
-          (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
-      },
-      combOp = {
-        case (
-          (featureSum1, squaredNormSum1, numOfPoints1),
-          (featureSum2, squaredNormSum2, numOfPoints2)
-          ) =>
-          BLAS.axpy(1.0, featureSum2, featureSum1)
-          (featureSum1, squaredNormSum1 + squaredNormSum2, numOfPoints1 + numOfPoints2)
-      }
-    )
+      .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2), row.getDouble(3))) }
+      .aggregateByKey
+      [(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense, 0.0, 0.0))(
+        seqOp = {
+          case (
+            (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
+            (features, squaredNorm, weight)
+            ) =>
+            require(weight >= 0.0, s"illegal weight value: $weight.  weight must be >= 0.0.")
+            BLAS.axpy(weight, features, featureSum)
+            (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
+        },
+        combOp = {
+          case (
+            (featureSum1, squaredNormSum1, weightSum1),
+            (featureSum2, squaredNormSum2, weightSum2)
+            ) =>
+            BLAS.axpy(1.0, featureSum2, featureSum1)
+            (featureSum1, squaredNormSum1 + squaredNormSum2, weightSum1 + weightSum2)
+        }
+      )
 
     clustersStatsRDD
       .collectAsMap()
       .mapValues {
-        case (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long) =>
-          SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, numOfPoints)
+        case (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double) =>
+          SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, weightSum)
       }
       .toMap
   }
@@ -324,6 +329,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
    * @param broadcastedClustersMap A map of the precomputed values for each cluster.
    * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point.
    * @param clusterId The id of the cluster the current point belongs to.
+   * @param weight The instance weight of the current point.
    * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point.
    * @return The Silhouette for the point.
    */
@@ -331,6 +337,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
       broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]],
       point: Vector,
       clusterId: Double,
+      weight: Double,
       squaredNorm: Double): Double = {
 
     def compute(targetClusterId: Double): Double = {
@@ -338,13 +345,14 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
       val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum)
 
       squaredNorm +
-        clusterStats.squaredNormSum / clusterStats.numOfPoints -
-        2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints
+        clusterStats.squaredNormSum / clusterStats.weightSum -
+        2 * pointDotClusterFeaturesSum / clusterStats.weightSum
     }
 
     pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
       clusterId,
-      broadcastedClustersMap.value(clusterId).numOfPoints,
+      broadcastedClustersMap.value(clusterId).weightSum,
+      weight,
       compute)
   }
 
@@ -355,12 +363,14 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
    * @param predictionCol The name of the column which contains the predicted cluster id
    *                      for the point.
    * @param featuresCol The name of the column which contains the feature vector of the point.
+   * @param weightCol The name of the column which contains instance weight.
    * @return The average of the Silhouette values of the clustered data.
    */
   def computeSilhouetteScore(
       dataset: Dataset[_],
       predictionCol: String,
-      featuresCol: String): Double = {
+      featuresCol: String,
+      weightCol: String): Double = {
     SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext)
 
     val squaredNormUDF = udf {
@@ -370,7 +380,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
 
     // compute aggregate values for clusters needed by the algorithm
     val clustersStatsMap = SquaredEuclideanSilhouette
-      .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol)
+      .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol, weightCol)
 
     // Silhouette is reasonable only when the number of clusters is greater then 1
     assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
@@ -378,12 +388,12 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
     val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
 
     val computeSilhouetteCoefficientUDF = udf {
-      computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
+      computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double, _: Double)
     }
 
     val silhouetteScore = overallScore(dfWithSquaredNorm,
       computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType),
-        col("squaredNorm")))
+        col(weightCol), col("squaredNorm")), col(weightCol))
 
     bClustersStatsMap.destroy()
 
@@ -472,30 +482,35 @@ private[evaluation] object CosineSilhouette extends Silhouette {
    * about a cluster which are needed by the algorithm.
    *
    * @param df The DataFrame which contains the input data
+   * @param featuresCol The name of the column which contains the feature vector of the point.
    * @param predictionCol The name of the column which contains the predicted cluster id
    *                      for the point.
+   * @param weightCol The name of the column which contains the instance weight.
    * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a
    *         its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`).
    */
   def computeClusterStats(
       df: DataFrame,
       featuresCol: String,
-      predictionCol: String): Map[Double, (Vector, Long)] = {
+      predictionCol: String,
+      weightCol: String): Map[Double, (Vector, Double)] = {
     val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
-      col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
+      col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol))
       .rdd
-      .map { row => (row.getDouble(0), row.getAs[Vector](1)) }
-      .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))(
+      .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) }
+      .aggregateByKey[(DenseVector, Double)]((Vectors.zeros(numFeatures).toDense, 0.0))(
       seqOp = {
-        case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) =>
-          BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum)
-          (normalizedFeaturesSum, numOfPoints + 1)
+        case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
+        (normalizedFeatures, weight)) =>
+          require(weight >= 0.0, s"illegal weight value: $weight.  weight must be >= 0.0.")
+          BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
+          (normalizedFeaturesSum, weightSum + weight)
       },
       combOp = {
-        case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) =>
+        case ((normalizedFeaturesSum1, weightSum1), (normalizedFeaturesSum2, weightSum2)) =>
           BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1)
-          (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2)
+          (normalizedFeaturesSum1, weightSum1 + weightSum2)
       }
     )
 
@@ -511,11 +526,13 @@ private[evaluation] object CosineSilhouette extends Silhouette {
    * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the
    *                           normalized features of the current point.
    * @param clusterId The id of the cluster the current point belongs to.
+   * @param weight The instance weight of the current point.
    */
   def computeSilhouetteCoefficient(
-      broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]],
+      broadcastedClustersMap: Broadcast[Map[Double, (Vector, Double)]],
       normalizedFeatures: Vector,
-      clusterId: Double): Double = {
+      clusterId: Double,
+      weight: Double): Double = {
 
     def compute(targetClusterId: Double): Double = {
       val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId)
@@ -525,6 +542,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
     pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
       clusterId,
       broadcastedClustersMap.value(clusterId)._2,
+      weight,
       compute)
   }
 
@@ -535,12 +553,14 @@ private[evaluation] object CosineSilhouette extends Silhouette {
    * @param predictionCol The name of the column which contains the predicted cluster id
    *                      for the point.
    * @param featuresCol The name of the column which contains the feature vector of the point.
+   * @param weightCol The name of the column which contains the instance weight.
    * @return The average of the Silhouette values of the clustered data.
    */
   def computeSilhouetteScore(
       dataset: Dataset[_],
       predictionCol: String,
-      featuresCol: String): Double = {
+      featuresCol: String,
+      weightCol: String): Double = {
     val normalizeFeatureUDF = udf {
       features: Vector => {
         val norm = Vectors.norm(features, 2.0)
@@ -553,7 +573,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
 
     // compute aggregate values for clusters needed by the algorithm
     val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol,
-      predictionCol)
+      predictionCol, weightCol)
 
     // Silhouette is reasonable only when the number of clusters is greater then 1
     assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
@@ -561,12 +581,12 @@ private[evaluation] object CosineSilhouette extends Silhouette {
     val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
 
     val computeSilhouetteCoefficientUDF = udf {
-      computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double)
+      computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
     }
 
     val silhouetteScore = overallScore(dfWithNormalizedFeatures,
       computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName),
-        col(predictionCol).cast(DoubleType)))
+        col(predictionCol).cast(DoubleType), col(weightCol)), col(weightCol))
 
     bClustersStatsMap.destroy()
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index 29fed53..d4c620a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.lit
 
 
 class ClusteringEvaluatorSuite
@@ -161,4 +162,44 @@ class ClusteringEvaluatorSuite
 
     assert(evaluator.evaluate(irisDataset) == silhouetteScoreCosin)
   }
+
+  test("test weight support") {
+    Seq("squaredEuclidean", "cosine").foreach { distanceMeasure =>
+      val evaluator1 = new ClusteringEvaluator()
+        .setFeaturesCol("features")
+        .setPredictionCol("label")
+        .setDistanceMeasure(distanceMeasure)
+
+      val evaluator2 = new ClusteringEvaluator()
+        .setFeaturesCol("features")
+        .setPredictionCol("label")
+        .setDistanceMeasure(distanceMeasure)
+        .setWeightCol("weight")
+
+      Seq(0.25, 1.0, 10.0, 99.99).foreach { w =>
+        var score1 = evaluator1.evaluate(irisDataset)
+        var score2 = evaluator2.evaluate(irisDataset.withColumn("weight", lit(w)))
+        assert(score1 ~== score2 relTol 1e-6)
+
+        score1 = evaluator1.evaluate(newIrisDataset)
+        score2 = evaluator2.evaluate(newIrisDataset.withColumn("weight", lit(w)))
+        assert(score1 ~== score2 relTol 1e-6)
+      }
+    }
+  }
+
+  test("single-element clusters with weight") {
+    val singleItemClusters = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      (0.0, Vectors.dense(5.1, 3.5, 1.4, 0.2), 6.0),
+      (1.0, Vectors.dense(7.0, 3.2, 4.7, 1.4), 0.25),
+      (2.0, Vectors.dense(6.3, 3.3, 6.0, 2.5), 9.99)))).toDF("label", "features", "weight")
+    Seq("squaredEuclidean", "cosine").foreach { distanceMeasure =>
+      val evaluator = new ClusteringEvaluator()
+        .setFeaturesCol("features")
+        .setPredictionCol("label")
+        .setDistanceMeasure(distanceMeasure)
+        .setWeightCol("weight")
+      assert(evaluator.evaluate(singleItemClusters) === 0.0)
+    }
+  }
 }
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 265f02c..a69a57f58 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -654,7 +654,7 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
 
 @inherit_doc
-class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
+class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol,
                           JavaMLReadable, JavaMLWritable):
     """
     Evaluator for Clustering results, which expects two input
@@ -677,6 +677,18 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
     ClusteringEvaluator...
     >>> evaluator.evaluate(dataset)
     0.9079...
+    >>> featureAndPredictionsWithWeight = map(lambda x: (Vectors.dense(x[0]), x[1], x[2]),
+    ...     [([0.0, 0.5], 0.0, 2.5), ([0.5, 0.0], 0.0, 2.5), ([10.0, 11.0], 1.0, 2.5),
+    ...     ([10.5, 11.5], 1.0, 2.5), ([1.0, 1.0], 0.0, 2.5), ([8.0, 6.0], 1.0, 2.5)])
+    >>> dataset = spark.createDataFrame(
+    ...     featureAndPredictionsWithWeight, ["features", "prediction", "weight"])
+    >>> evaluator = ClusteringEvaluator()
+    >>> evaluator.setPredictionCol("prediction")
+    ClusteringEvaluator...
+    >>> evaluator.setWeightCol("weight")
+    ClusteringEvaluator...
+    >>> evaluator.evaluate(dataset)
+    0.9079...
     >>> ce_path = temp_path + "/ce"
     >>> evaluator.save(ce_path)
     >>> evaluator2 = ClusteringEvaluator.load(ce_path)
@@ -694,10 +706,10 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
 
     @keyword_only
     def __init__(self, predictionCol="prediction", featuresCol="features",
-                 metricName="silhouette", distanceMeasure="squaredEuclidean"):
+                 metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
         """
         __init__(self, predictionCol="prediction", featuresCol="features", \
-                 metricName="silhouette", distanceMeasure="squaredEuclidean")
+                 metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
         """
         super(ClusteringEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -709,10 +721,10 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
     @keyword_only
     @since("2.3.0")
     def setParams(self, predictionCol="prediction", featuresCol="features",
-                  metricName="silhouette", distanceMeasure="squaredEuclidean"):
+                  metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
         """
         setParams(self, predictionCol="prediction", featuresCol="features", \
-                  metricName="silhouette", distanceMeasure="squaredEuclidean")
+                  metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
         Sets params for clustering evaluator.
         """
         kwargs = self._input_kwargs
@@ -758,6 +770,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
         """
         return self._set(predictionCol=value)
 
+    @since("3.1.0")
+    def setWeightCol(self, value):
+        """
+        Sets the value of :py:attr:`weightCol`.
+        """
+        return self._set(weightCol=value)
+
 
 @inherit_doc
 class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org