You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by db...@apache.org on 2018/04/23 20:23:06 UTC
spark git commit: [SPARK-11237][ML] Add pmml export for k-means in Spark ML
Repository: spark
Updated Branches:
refs/heads/master 770add81c -> e82cb6834
[SPARK-11237][ML] Add pmml export for k-means in Spark ML
## What changes were proposed in this pull request?
Adding PMML export to Spark ML's KMeans Model.
## How was this patch tested?
New unit test for Spark ML PMML export based on the old Spark MLlib unit test.
Author: Holden Karau <ho...@pigscanfly.ca>
Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e82cb683
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e82cb683
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e82cb683
Branch: refs/heads/master
Commit: e82cb68349b785c1b35bcfb85bff3a8ec2c93fee
Parents: 770add8
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Mon Apr 23 13:23:02 2018 -0700
Committer: DB Tsai <d_...@apple.com>
Committed: Mon Apr 23 13:23:02 2018 -0700
----------------------------------------------------------------------
.../org.apache.spark.ml.util.MLFormatRegister | 4 +-
.../org/apache/spark/ml/clustering/KMeans.scala | 75 +++++++++++++-------
.../spark/ml/regression/LinearRegression.scala | 2 +-
.../spark/ml/clustering/KMeansSuite.scala | 32 ++++++++-
4 files changed, 83 insertions(+), 30 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
----------------------------------------------------------------------
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
index 5e5484f..f14431d 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -1,2 +1,4 @@
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
-org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
\ No newline at end of file
+org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
+org.apache.spark.ml.clustering.InternalKMeansModelWriter
+org.apache.spark.ml.clustering.PMMLKMeansModelWriter
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 987a428..1ad157a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,11 +17,13 @@
package org.apache.spark.ml.clustering
+import scala.collection.mutable
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
- private val parentModel: MLlibKMeansModel)
- extends Model[KMeansModel] with KMeansParams with MLWritable {
+ private[clustering] val parentModel: MLlibKMeansModel)
+ extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -152,14 +154,14 @@ class KMeansModel private[ml] (
}
/**
- * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+ * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[KMeansModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
*/
@Since("1.6.0")
- override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+ override def write: GeneralMLWriter = new GeneralMLWriter(this)
private var trainingSummary: Option[KMeansSummary] = None
@@ -185,6 +187,47 @@ class KMeansModel private[ml] (
}
}
+/** Helper class for storing model data */
+private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+
+
+/** A writer for KMeans that handles the "internal" (or default) format */
+private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "internal"
+ override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[KMeansModel]
+ val sc = sparkSession.sparkContext
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: cluster centers
+ val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
+ case (center, idx) =>
+ ClusterData(idx, center)
+ }
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
+ }
+}
+
+/** A writer for KMeans that handles the "pmml" format */
+private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "pmml"
+ override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[KMeansModel]
+ val sc = sparkSession.sparkContext
+ instance.parentModel.toPMML(sc, path)
+ }
+}
+
+
@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {
@@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)
- /** Helper class for storing model data */
- private case class Data(clusterIdx: Int, clusterCenter: Vector)
-
/**
* We store all cluster centers in a single row and use this class to store model data by
* Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
*/
private case class OldData(clusterCenters: Array[OldVector])
- /** [[MLWriter]] instance for [[KMeansModel]] */
- private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: cluster centers
- val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
- Data(idx, center)
- }
- val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
- }
- }
-
private class KMeansModelReader extends MLReader[KMeansModel] {
/** Checked against metadata when loading model */
@@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
- val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
+ val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f67d9d8..9cdd3a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter
/** A writer for LinearRegression that handles the "pmml" format */
private class PMMLLinearRegressionModelWriter
- extends MLWriterFormat with MLFormatRegister {
+ extends MLWriterFormat with MLFormatRegister {
override def format(): String = "pmml"
http://git-wip-us.apache.org/repos/asf/spark/blob/e82cb683/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 32830b3..77c9d48 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering
import scala.util.Random
+import org.dmg.pmml.{ClusteringModel, PMML}
+
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
+ KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector)
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
+ with PMMLReadWriteTest {
final val k = 5
@transient var dataset: Dataset[_] = _
@@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
KMeansSuite.allParamSettings, checkModelData)
}
+
+ test("pmml export") {
+ val clusterCenters = Array(
+ MLlibVectors.dense(1.0, 2.0, 6.0),
+ MLlibVectors.dense(1.0, 3.0, 0.0),
+ MLlibVectors.dense(1.0, 4.0, 6.0))
+ val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
+ val kmeansModel = new KMeansModel("", oldKmeansModel)
+ def checkModel(pmml: PMML): Unit = {
+ // Check the header description is what we expect
+ assert(pmml.getHeader.getDescription === "k-means clustering")
+ // Check that the number of fields matches the size of a single cluster-center vector
+ assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+ // This verifies that there is a model attached to the pmml object and that the model is a
+ // clustering one. It also verifies that the pmml model has the same number of clusters as
+ // the Spark model.
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+ assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+ }
+ testPMMLWrite(sc, kmeansModel, checkModel)
+ }
}
object KMeansSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org