Posted to commits@spark.apache.org by db...@apache.org on 2018/04/23 20:23:06 UTC

spark git commit: [SPARK-11237][ML] Add pmml export for k-means in Spark ML

Repository: spark
Updated Branches:
  refs/heads/master 770add81c -> e82cb6834


[SPARK-11237][ML] Add pmml export for k-means in Spark ML

## What changes were proposed in this pull request?

Adding PMML export support to Spark ML's KMeansModel.
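
As a usage sketch (the dataset and output path here are hypothetical), a fitted
KMeansModel can now be saved in PMML format through the generalized writer API,
mirroring what LinearRegression already supports:

  import org.apache.spark.ml.clustering.KMeans

  // Fit a model on some feature DataFrame (assumed to exist as `dataset`).
  val model = new KMeans().setK(3).fit(dataset)
  // "pmml" dispatches to PMMLKMeansModelWriter via MLFormatRegister;
  // plain save() keeps using the default "internal" format.
  model.write.format("pmml").save("/tmp/kmeans-pmml")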

## How was this patch tested?

A new unit test for Spark ML PMML export, based on the existing Spark MLlib unit test.

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e82cb683
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e82cb683
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e82cb683

Branch: refs/heads/master
Commit: e82cb68349b785c1b35bcfb85bff3a8ec2c93fee
Parents: 770add8
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Mon Apr 23 13:23:02 2018 -0700
Committer: DB Tsai <d_...@apple.com>
Committed: Mon Apr 23 13:23:02 2018 -0700

----------------------------------------------------------------------
 .../org.apache.spark.ml.util.MLFormatRegister   |  4 +-
 .../org/apache/spark/ml/clustering/KMeans.scala | 75 +++++++++++++-------
 .../spark/ml/regression/LinearRegression.scala  |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala       | 32 ++++++++-
 4 files changed, 83 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
----------------------------------------------------------------------
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
index 5e5484f..f14431d 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -1,2 +1,4 @@
 org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
-org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
\ No newline at end of file
+org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
+org.apache.spark.ml.clustering.InternalKMeansModelWriter
+org.apache.spark.ml.clustering.PMMLKMeansModelWriter
\ No newline at end of file
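
The services file above is a standard java.util.ServiceLoader registration:
Spark's GeneralMLWriter discovers writer implementations listed here and matches
them by format() and stageName(). A minimal sketch of that lookup, illustrative
only and not code from this patch:

  import java.util.ServiceLoader
  import scala.collection.JavaConverters._
  import org.apache.spark.ml.util.MLFormatRegister

  // Enumerate every registered format writer on the classpath and print
  // which (format, stage) pair each one claims to handle.
  val writers = ServiceLoader.load(classOf[MLFormatRegister]).asScala
  writers.foreach(w => println(s"${w.format()} -> ${w.stageName()}"))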

http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 987a428..1ad157a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.collection.mutable
+
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Since("1.5.0")
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel)
-  extends Model[KMeansModel] with KMeansParams with MLWritable {
+    private[clustering] val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
   }
 
   /**
-   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+   * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
    *
    * For [[KMeansModel]], this does NOT currently save the training [[summary]].
    * An option to save [[summary]] may be added in the future.
    *
    */
   @Since("1.6.0")
-  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
 
   private var trainingSummary: Option[KMeansSummary] = None
 
@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
   }
 }
 
+/** Helper class for storing model data */
+private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+
+
+/** A writer for KMeans that handles the "internal" (or default) format */
+private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "internal"
+  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val instance = stage.asInstanceOf[KMeansModel]
+    val sc = sparkSession.sparkContext
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: cluster centers
+    val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
+      case (center, idx) =>
+        ClusterData(idx, center)
+    }
+    val dataPath = new Path(path, "data").toString
+    sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
+  }
+}
+
+/** A writer for KMeans that handles the "pmml" format */
+private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+  override def format(): String = "pmml"
+  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+  override def write(path: String, sparkSession: SparkSession,
+    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+    val instance = stage.asInstanceOf[KMeansModel]
+    val sc = sparkSession.sparkContext
+    instance.parentModel.toPMML(sc, path)
+  }
+}
+
+
 @Since("1.6.0")
 object KMeansModel extends MLReadable[KMeansModel] {
 
@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
   @Since("1.6.0")
   override def load(path: String): KMeansModel = super.load(path)
 
-  /** Helper class for storing model data */
-  private case class Data(clusterIdx: Int, clusterCenter: Vector)
-
   /**
    * We store all cluster centers in a single row and use this class to store model data by
    * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
    */
   private case class OldData(clusterCenters: Array[OldVector])
 
-  /** [[MLWriter]] instance for [[KMeansModel]] */
-  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
-
-    override protected def saveImpl(path: String): Unit = {
-      // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      // Save model data: cluster centers
-      val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
-        Data(idx, center)
-      }
-      val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
-    }
-  }
-
   private class KMeansModelReader extends MLReader[KMeansModel] {
 
     /** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val dataPath = new Path(path, "data").toString
 
       val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
-        val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
+        val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
         // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
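
The two writers added above follow Spark's pluggable MLWriterFormat pattern. For
illustration, a hedged sketch of a hypothetical third format (not part of this
patch) that would plug in the same way, provided it were also listed in the
MLFormatRegister services file:

  import scala.collection.mutable
  import org.apache.spark.ml.PipelineStage
  import org.apache.spark.ml.clustering.KMeansModel
  import org.apache.spark.ml.util.{MLFormatRegister, MLWriterFormat}
  import org.apache.spark.sql.SparkSession

  /** Hypothetical writer saving cluster centers as JSON rows. */
  private class JsonKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
    override def format(): String = "centers-json"
    override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"

    override def write(path: String, sparkSession: SparkSession,
        optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
      val model = stage.asInstanceOf[KMeansModel]
      // One row per center: (clusterIdx, center values), written as JSON.
      val rows = model.clusterCenters.zipWithIndex.map { case (c, i) => (i, c.toArray) }
      sparkSession.createDataFrame(rows.toSeq).toDF("clusterIdx", "clusterCenter")
        .write.json(path)
    }
  }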

http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f67d9d8..9cdd3a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter
 
 /** A writer for LinearRegression that handles the "pmml" format */
 private class PMMLLinearRegressionModelWriter
-    extends MLWriterFormat with MLFormatRegister {
+  extends MLWriterFormat with MLFormatRegister {
 
   override def format(): String = "pmml"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 32830b3..77c9d48 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering
 
 import scala.util.Random
 
+import org.dmg.pmml.{ClusteringModel, PMML}
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
+  KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
 private[clustering] case class TestRow(features: Vector)
 
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
+  with PMMLReadWriteTest {
 
   final val k = 5
   @transient var dataset: Dataset[_] = _
@@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
       KMeansSuite.allParamSettings, checkModelData)
   }
+
+  test("pmml export") {
+    val clusterCenters = Array(
+      MLlibVectors.dense(1.0, 2.0, 6.0),
+      MLlibVectors.dense(1.0, 3.0, 0.0),
+      MLlibVectors.dense(1.0, 4.0, 6.0))
+    val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
+    val kmeansModel = new KMeansModel("", oldKmeansModel)
+    def checkModel(pmml: PMML): Unit = {
+      // Check that the header description is what we expect
+      assert(pmml.getHeader.getDescription === "k-means clustering")
+      // Check that the number of fields matches the cluster center vector size
+      assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+      // This verifies that a model is attached to the PMML object and that it is a
+      // clustering one. It also verifies that the PMML model has the same number of
+      // clusters as the Spark model.
+      val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+      assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+    }
+    testPMMLWrite(sc, kmeansModel, checkModel)
+  }
 }
 
 object KMeansSuite {
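
For reference, a rough sketch of reading a PMML file written by this path back
into the org.dmg.pmml model inspected in the test above. The part-file path is
an assumption (toPMML writes the model as plain XML text via the SparkContext),
and JAXBUtil comes from the JPMML pmml-model library Spark already depends on:

  import java.io.FileInputStream
  import javax.xml.transform.stream.StreamSource
  import org.dmg.pmml.PMML
  import org.jpmml.model.JAXBUtil

  // Unmarshal the saved XML into a PMML object, then inspect it the same
  // way checkModel does in the test.
  val source = new StreamSource(new FileInputStream("/tmp/kmeans-pmml/part-00000"))
  val pmml: PMML = JAXBUtil.unmarshalPMML(source)
  assert(pmml.getHeader.getDescription == "k-means clustering")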

