You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/13 20:53:25 UTC

spark git commit: [SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel

Repository: spark
Updated Branches:
  refs/heads/master 6cc5b3ed3 -> 1e340c3ae


[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel

See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988).

Author: Xusen Yin <yi...@gmail.com>

Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits:

cb1ecfa [Xusen Yin] change Assignment into case class
b1dd24c [Xusen Yin] add test suite
63c3923 [Xusen Yin] add save load for power iteration clustering


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e340c3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e340c3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e340c3a

Branch: refs/heads/master
Commit: 1e340c3ae4d5361d048a3d6990f144cfc923666f
Parents: 6cc5b3e
Author: Xusen Yin <yi...@gmail.com>
Authored: Mon Apr 13 11:53:17 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Apr 13 11:53:17 2015 -0700

----------------------------------------------------------------------
 .../clustering/PowerIterationClustering.scala   | 68 ++++++++++++++++++--
 .../PowerIterationClusteringSuite.scala         | 34 ++++++++++
 2 files changed, 97 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1e340c3a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 1800239..aa53e88 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -17,15 +17,20 @@
 
 package org.apache.spark.mllib.clustering
 
-import org.apache.spark.{Logging, SparkException}
+import org.json4s.JsonDSL._
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.{Logging, SparkContext, SparkException}
 
 /**
  * :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
 @Experimental
 class PowerIterationClusteringModel(
     val k: Int,
-    val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
+    val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+/**
+ * Companion object providing persistence support for [[PowerIterationClusteringModel]].
+ */
+object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
+
+  /**
+   * Loads a model from `path` by delegating to the format version 1.0 reader.
+   *
+   * @param sc Spark context used to read the stored files
+   * @param path location the model was saved under
+   * @return the restored model
+   */
+  override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+    PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  /** Reader and writer for storage format version 1.0. */
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
+
+    /**
+     * Writes `model` under `path` in two parts: a single JSON metadata record holding the class
+     * name, the format version and `k`, saved as text at `Loader.metadataPath(path)`, and the
+     * assignments, saved as a Parquet file at `Loader.dataPath(path)`.
+     *
+     * @param sc Spark context used to create the SQL context and the metadata RDD
+     * @param model model to persist
+     * @param path location to save the model under
+     */
+    def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      // Brings the RDD-to-DataFrame conversion (`toDF`) into scope.
+      import sqlContext.implicits._
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      // A single partition keeps the metadata in one output file.
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val dataRDD = model.assignments.toDF()
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    /**
+     * Reads back a model written by [[save]].
+     *
+     * The stored class name and format version must match the ones this reader handles, and
+     * the schema of the stored assignments is checked before the rows are converted.
+     *
+     * @param sc Spark context used to read the stored files
+     * @param path location the model was saved under
+     * @return the restored model
+     */
+    def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+      // Required by json4s `extract` below.
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      // Say what was expected and what was found, so a mismatch can be diagnosed from the
+      // error alone instead of a bare "assertion failed".
+      assert(className == thisClassName,
+        s"Expected class $thisClassName but the metadata at $path declares $className")
+      assert(formatVersion == thisFormatVersion,
+        s"Expected format version $thisFormatVersion but the metadata at $path declares " +
+          formatVersion)
+
+      val k = (metadata \ "k").extract[Int]
+      val assignments = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
+
+      val assignmentsRDD = assignments.map {
+        case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
+      }
+
+      new PowerIterationClusteringModel(k, assignmentsRDD)
+    }
+  }
+}
 
 /**
  * :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
     val v = powerIter(w, maxIterations)
     val assignments = kMeans(v, k).mapPartitions({ iter =>
       iter.map { case (id, cluster) =>
-        new Assignment(id, cluster)
+        Assignment(id, cluster)
       }
     }, preservesPartitioning = true)
     new PowerIterationClusteringModel(k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
    * @param cluster assigned cluster id
    */
   @Experimental
-  class Assignment(val id: Long, val cluster: Int) extends Serializable
+  case class Assignment(id: Long, cluster: Int)
 
   /**
    * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).

http://git-wip-us.apache.org/repos/asf/spark/blob/1e340c3a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 6315c03..6d6fe6f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable
+import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkContext
 import org.apache.spark.graphx.{Edge, Graph}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
       assert(x ~== u1(i.toInt) absTol 1e-14)
     }
   }
+
+  // Round-trips a small model through save and load and compares the two copies.
+  test("model save/load") {
+    val original = PowerIterationClusteringSuite.createModel(sc, k = 3, nPoints = 10)
+    val scratchDir = Utils.createTempDir()
+    val location = scratchDir.toURI.toString
+    try {
+      original.save(sc, location)
+      val restored = PowerIterationClusteringModel.load(sc, location)
+      PowerIterationClusteringSuite.checkEqual(original, restored)
+    } finally {
+      // Remove the scratch directory even when a step above fails.
+      Utils.deleteRecursively(scratchDir)
+    }
+  }
+}
+
+/**
+ * Helpers for building and comparing [[PowerIterationClusteringModel]] instances in tests.
+ */
+object PowerIterationClusteringSuite extends FunSuite {
+
+  /**
+   * Builds a model whose assignments cover the ids `0 until nPoints`, each mapped to a
+   * randomly drawn cluster in `[0, k)`.
+   *
+   * @param sc Spark context used to parallelize the assignments
+   * @param k number of clusters
+   * @param nPoints number of assignments to generate
+   * @return a model holding `k` and the generated assignments
+   */
+  def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
+    val assignments = sc.parallelize(
+      (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
+    new PowerIterationClusteringModel(k, assignments)
+  }
+
+  /**
+   * Asserts that two models have the same `k` and the same id-to-cluster assignments.
+   *
+   * @param a first model
+   * @param b second model
+   */
+  def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
+    assert(a.k === b.k)
+
+    val aAssignments = a.assignments.map(x => (x.id, x.cluster))
+    val bAssignments = b.assignments.map(x => (x.id, x.cluster))
+    // A full outer join keeps ids that appear in only one model (the absent side is None), so
+    // a missing or extra assignment counts as a mismatch. An inner join would drop such ids
+    // and let the check pass.
+    val unequalElements = aAssignments.fullOuterJoin(bAssignments).filter {
+      case (_, (c1, c2)) => c1 != c2 }.count()
+    assert(unequalElements === 0L)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org