You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/07/04 14:56:31 UTC

spark git commit: [MINOR][ML] Minor correction in the powerIterationSuite

Repository: spark
Updated Branches:
  refs/heads/master 1a2655a9e -> ca8243f30


[MINOR][ML] Minor correction in the powerIterationSuite

## What changes were proposed in this pull request?

Currently, the power iteration clustering test in Spark ML maps the results to the labels 0 and 1 for assertion. Since the cluster labels produced by the algorithm need not match the mapped labels, the test case may fail. Even if the mapping happens to be correct, we cannot theoretically guarantee which set receives which cluster label: KMeans can assign label 0 to either of the sets.

The PowerIterationClusteringSuite in MLlib checks the clustering results without mapping them to a particular cluster label, as shown below:
```scala
val predictions = Array.fill(2)(mutable.Set.empty[Long])
model.assignments.collect().foreach { a =>
  predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet))
```

## How was this patch tested?
Existing tests

Author: Shahid <sh...@gmail.com>

Closes #21689 from shahidki31/picTestSuiteMinorCorrection.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca8243f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca8243f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca8243f3

Branch: refs/heads/master
Commit: ca8243f30fc6939ee099a9534e3b811d5c64d2cf
Parents: 1a2655a
Author: Shahid <sh...@gmail.com>
Authored: Wed Jul 4 09:56:24 2018 -0500
Committer: Sean Owen <sr...@gmail.com>
Committed: Wed Jul 4 09:56:24 2018 -0500

----------------------------------------------------------------------
 .../PowerIterationClusteringSuite.scala         | 30 +++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca8243f3/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index b707272..55b460f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.collection.mutable
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setMaxIter(40)
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments = assignments
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
-      (n1 until n).map(x => (x, 0)).toSet
-    assert(localAssignments === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions = Array.fill(2)(mutable.Set.empty[Long])
+    assignments.foreach {
+      case (id, cluster) => predictions(cluster) += id
+    }
+    assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
 
     val assignments2 = new PowerIterationClustering()
       .setK(2)
@@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setInitMode("degree")
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments2 = assignments2
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    assert(localAssignments2 === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+    assignments2.foreach {
+      case (id, cluster) => predictions2(cluster) += id
+    }
+    assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
   }
 
   test("supported input types") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org