You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/03/31 20:12:52 UTC

spark git commit: [SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans

Repository: spark
Updated Branches:
  refs/heads/master 3b3cc7600 -> a0a199158


[SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans

## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-13782
Model export/import for BisectingKMeans in spark.ml and mllib

## How was this patch tested?

Added read/write unit tests to the spark.ml and spark.mllib BisectingKMeansSuite.

Author: Yuhao Yang <hh...@gmail.com>

Closes #11933 from hhbyyh/bisectingsave.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a0a19915
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a0a19915
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a0a19915

Branch: refs/heads/master
Commit: a0a1991580ed24230f88cae9f5a4dfbe58f03b28
Parents: 3b3cc76
Author: Yuhao Yang <hh...@gmail.com>
Authored: Thu Mar 31 11:12:40 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Mar 31 11:12:40 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/clustering/BisectingKMeans.scala   | 59 ++++++++++--
 .../mllib/clustering/BisectingKMeans.scala      |  2 +-
 .../mllib/clustering/BisectingKMeansModel.scala | 98 +++++++++++++++++++-
 .../ml/clustering/BisectingKMeansSuite.scala    | 22 ++++-
 .../mllib/clustering/BisectingKMeansSuite.scala | 18 ++++
 5 files changed, 190 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a0a19915/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f014a1d..55f751c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.
   {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params
 
   /** @group expertParam */
   @Since("2.0.0")
-  final val minDivisibleClusterSize = new Param[Double](
+  final val minDivisibleClusterSize = new DoubleParam(
     this,
     "minDivisibleClusterSize",
     "the minimum number of points (if >= 1.0) or the minimum proportion",
@@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params
 class BisectingKMeansModel private[ml] (
     @Since("2.0.0") override val uid: String,
     private val parentModel: MLlibBisectingKMeansModel
-  ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+  ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -115,6 +117,44 @@ class BisectingKMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  /**
+   * Returns an [[MLWriter]] for this model; it writes the spark.ml metadata/params and the
+   * wrapped spark.mllib model (see `BisectingKMeansModelWriter` in the companion object).
+   */
+  @Since("2.0.0")
+  override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+}
+
+/** Companion providing read/load support for the spark.ml [[BisectingKMeansModel]]. */
+object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
+
+  /** Returns an [[MLReader]] that restores a model written by `write`. */
+  @Since("2.0.0")
+  override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader
+
+  /** Loads the model stored at `path`. */
+  @Since("2.0.0")
+  override def load(path: String): BisectingKMeansModel = super.load(path)
+
+  /**
+   * [[MLWriter]] instance for [[BisectingKMeansModel]].
+   *
+   * Writes the spark.ml metadata and params to `path`, then delegates the tree itself to the
+   * wrapped spark.mllib model, which saves into the `data` subdirectory.
+   */
+  private[BisectingKMeansModel]
+  class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      instance.parentModel.save(sc, new Path(path, "data").toString)
+    }
+  }
+
+  /** [[MLReader]] that rebuilds the model from its metadata plus the `data` subdirectory. */
+  private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] {
+
+    /** Class name the saved metadata is checked against when loading. */
+    private val className = classOf[BisectingKMeansModel].getName
+
+    override def load(path: String): BisectingKMeansModel = {
+      val meta = DefaultParamsReader.loadMetadata(path, sc, className)
+      val parent = MLlibBisectingKMeansModel.load(sc, new Path(path, "data").toString)
+      val restored = new BisectingKMeansModel(meta.uid, parent)
+      // Re-apply the params recorded in the metadata onto the freshly built model.
+      DefaultParamsReader.getAndSetParams(restored, meta)
+      restored
+    }
+  }
 }
 
 /**
@@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] (
 @Experimental
 class BisectingKMeans @Since("2.0.0") (
     @Since("2.0.0") override val uid: String)
-  extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+  extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable {
 
   setDefault(
     k -> 4,
@@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") (
   override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
 
   @Since("2.0.0")
-  def this() = this(Identifiable.randomUID("bisecting k-means"))
+  def this() = this(Identifiable.randomUID("bisecting-kmeans"))
 
   /** @group setParam */
   @Since("2.0.0")
@@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") (
   }
 }
 
+
+/** Companion used to reload a [[BisectingKMeans]] estimator persisted via its params writer. */
+@Since("2.0.0")
+object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
+
+  /** Reads back the estimator previously saved at `path`. */
+  @Since("2.0.0")
+  override def load(path: String): BisectingKMeans = super.load(path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a0a19915/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 64b838a..e4bd0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable {
 private[clustering] class ClusteringTreeNode private[clustering] (
     val index: Int,
     val size: Long,
-    private val centerWithNorm: VectorWithNorm,
+    private[clustering] val centerWithNorm: VectorWithNorm,
     val cost: Double,
     val height: Double,
     val children: Array[ClusteringTreeNode]) extends Serializable {

http://git-wip-us.apache.org/repos/asf/spark/blob/a0a19915/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 01a0d31..c3b5b8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -17,11 +17,19 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * Clustering model produced by [[BisectingKMeans]].
@@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class BisectingKMeansModel private[clustering] (
     private[clustering] val root: ClusteringTreeNode
-  ) extends Serializable with Logging {
+  ) extends Serializable with Saveable with Logging {
 
   /**
    * Leaf cluster centers.
@@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+
+  /**
+   * Saves this model under `path` using format version 1.0: JSON metadata plus a parquet
+   * table holding one row per tree node.
+   */
+  @Since("2.0.0")
+  override def save(sc: SparkContext, path: String): Unit =
+    BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
+
+  /** On-disk format version produced by [[save]]. */
+  override protected def formatVersion: String = "1.0"
+}
+
+/** Companion for the spark.mllib [[BisectingKMeansModel]]: loads models written by `save`. */
+@Since("2.0.0")
+object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
+
+  /**
+   * Loads a model previously saved to `path`.
+   *
+   * @param sc Spark context used to read the metadata and the node data
+   * @param path directory the model was saved to
+   * @return the reconstructed model
+   * @throws Exception if the saved (class name, format version) pair is not supported
+   */
+  @Since("2.0.0")
+  override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+    val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+    implicit val formats = DefaultFormats
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, formatVersion) match {
+      // Backquotes make this a stable-identifier pattern that compares against the expected
+      // class name; a bare lowercase name would bind a fresh variable and match anything.
+      case (`classNameV1_0`, "1.0") =>
+        // rootId is a V1.0 metadata field, so read it only once the format is known.
+        val rootId = (metadata \ "rootId").extract[Int]
+        SaveLoadV1_0.load(sc, path, rootId)
+      case _ => throw new Exception(
+        s"BisectingKMeansModel.load did not recognize model with (className, format version): " +
+          s"($loadedClassName, $formatVersion).  Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  /** One serialized tree node; `children` holds the indices of its child nodes. */
+  private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double,
+      height: Double, children: Seq[Int])
+
+  private object Data {
+    /** Decodes a row whose columns are selected in `Data`'s field order. */
+    def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3),
+      r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
+  }
+
+  private[clustering] object SaveLoadV1_0 {
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+    /**
+     * Writes JSON metadata (class, version, root node index) and a parquet table containing
+     * one row per tree node.
+     */
+    def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+      val sqlContext = SQLContext.getOrCreate(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+          ~ ("rootId" -> model.root.index)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val data = getNodes(model.root).map(node => Data(node.index, node.size,
+        node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+        node.children.map(_.index)))
+      val dataDF = sc.parallelize(data).toDF()
+      dataDF.write.parquet(Loader.dataPath(path))
+    }
+
+    /** Collects every node of the subtree rooted at `node`, children before their parent. */
+    private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+      if (node.children.isEmpty) {
+        Array(node)
+      } else {
+        node.children.flatMap(getNodes(_)) ++ Array(node)
+      }
+    }
+
+    /** Reads the node table written by `save` and rebuilds the tree rooted at `rootId`. */
+    def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+      val sqlContext = SQLContext.getOrCreate(sc)
+      val rows = sqlContext.read.parquet(Loader.dataPath(path))
+      Loader.checkSchema[Data](rows.schema)
+      val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+      val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+      val rootNode = buildTree(rootId, nodes)
+      new BisectingKMeansModel(rootNode)
+    }
+
+    /**
+     * Recursively reconstructs the [[ClusteringTreeNode]] with index `rootId`.
+     *
+     * @throws NoSuchElementException if `rootId`, or any child index it references, is
+     *                                missing from `nodes`
+     */
+    private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+      val root = nodes.getOrElse(rootId, throw new NoSuchElementException(
+        s"BisectingKMeansModel.load could not find tree node with index $rootId in saved data"))
+      // A leaf has no child indices, so this yields the empty array a leaf node needs.
+      val children = root.children.map(c => buildTree(c, nodes)).toArray
+      new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+        root.cost, root.height, children)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a0a19915/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c..18f2c99 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,10 +18,12 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
 
-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
   @transient var dataset: DataFrame = _
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
   }
+
+  test("read/write") {
+    // The reloaded model must reproduce the original model's cluster centers exactly.
+    def checkModelData(original: BisectingKMeansModel, reloaded: BisectingKMeansModel): Unit =
+      assert(original.clusterCenters === reloaded.clusterCenters)
+
+    testEstimatorAndModelReadWrite(
+      new BisectingKMeans(), dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object BisectingKMeansSuite {
+
+  /** Param values handed to `testEstimatorAndModelReadWrite` by the read/write test. */
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3, "maxIter" -> 2, "seed" -> -1L, "minDivisibleClusterSize" -> 2.0)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a0a19915/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index 41b9d5c..35f7932 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       }
     }
   }
+
+  test("BisectingKMeans model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val points = (1 until 8).map(i => Vectors.dense(i))
+    val data = sc.parallelize(points, 2)
+    val model = new BisectingKMeans().run(data)
+    try {
+      model.save(sc, path)
+      val sameModel = BisectingKMeansModel.load(sc, path)
+      // Equal k also guarantees the zip below compares every center, not a truncated prefix.
+      assert(model.k === sameModel.k)
+      // Each comparison must be wrapped in assert: a bare `===` inside foreach only builds a
+      // result that is discarded, so mismatched centers would go unnoticed.
+      model.clusterCenters.zip(sameModel.clusterCenters).foreach { case (expected, actual) =>
+        assert(expected === actual)
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org