You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/06 20:45:26 UTC

spark git commit: [SPARK-13538][ML] Add GaussianMixture to ML

Repository: spark
Updated Branches:
  refs/heads/master 8cffcb60d -> af73d9737


[SPARK-13538][ML] Add GaussianMixture to ML

JIRA: https://issues.apache.org/jira/browse/SPARK-13538

## What changes were proposed in this pull request?

Add GaussianMixture and GaussianMixtureModel to ML package

## How was this patch tested?

Unit tests and manual tests were run.
Local Scalastyle checks passed.

Author: Zheng RuiFeng <ru...@foxmail.com>
Author: Ruifeng Zheng <ru...@foxmail.com>
Author: Joseph K. Bradley <jo...@databricks.com>

Closes #11419 from zhengruifeng/mlgmm.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af73d973
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af73d973
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af73d973

Branch: refs/heads/master
Commit: af73d9737874f7adaec3cd19ac889ab3badb8e2a
Parents: 8cffcb6
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Wed Apr 6 11:45:16 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Apr 6 11:45:16 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/clustering/GaussianMixture.scala   | 311 +++++++++++++++++++
 .../ml/clustering/GaussianMixtureSuite.scala    | 133 ++++++++
 2 files changed, 444 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af73d973/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
new file mode 100644
index 0000000..120bf3c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+
+/**
+ * Common params for GaussianMixture and GaussianMixtureModel
+ */
+private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol
+  with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol {
+
+  /**
+   * Set the number of clusters to create (k). Must be > 1. Default: 2.
+   * @group param
+   */
+  @Since("2.0.0")
+  final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+  /** @group getParam */
+  @Since("2.0.0")
+  def getK: Int = $(k)
+
+  /**
+   * Validates and transforms the input schema.
+   *
+   * Checks that the features column is a vector column, then appends the prediction column
+   * (integer) followed by the probability column (vector).
+   *
+   * @param schema input schema
+   * @return output schema containing both the prediction and the probability columns
+   */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    // appendColumn returns a new schema rather than mutating its argument, so the
+    // intermediate result must be threaded into the second call; otherwise the
+    // prediction column is missing from the returned schema.
+    val withPrediction = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+    SchemaUtils.appendColumn(withPrediction, $(probabilityCol), new VectorUDT)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by GaussianMixture.
+ *
+ * Predictions, weights and gaussians are all obtained from the wrapped spark.mllib model.
+ *
+ * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureModel private[ml] (
+    @Since("2.0.0") override val uid: String,
+    private val parentModel: MLlibGMModel)
+  extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
+
+  /**
+   * Creates a copy of this model that wraps the same underlying mllib model, with `extra`
+   * params applied. The parent and the training summary (if one has been set) are carried
+   * over, so `hasSummary` on the copy agrees with this instance.
+   */
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): GaussianMixtureModel = {
+    val copied = copyValues(new GaussianMixtureModel(uid, parentModel), extra)
+    // Without this the copy would silently lose a summary attached via setSummary.
+    trainingSummary.foreach(s => copied.setSummary(s))
+    copied.setParent(this.parent)
+  }
+
+  /**
+   * Adds the prediction column and the probability column to `dataset`, both computed
+   * from the features column.
+   */
+  @Since("2.0.0")
+  override def transform(dataset: DataFrame): DataFrame = {
+    val predUDF = udf((vector: Vector) => predict(vector))
+    val probUDF = udf((vector: Vector) => predictProbability(vector))
+    dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
+      .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
+  }
+
+  /** Validates `schema` and returns the schema produced by this model. */
+  @Since("2.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  /** Predicted cluster for `features`, as returned by the wrapped mllib model. */
+  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+  /** Result of the wrapped model's `predictSoft` for `features`, as a dense vector. */
+  private[clustering] def predictProbability(features: Vector): Vector = {
+    Vectors.dense(parentModel.predictSoft(features))
+  }
+
+  /** Weights of the wrapped mllib model. */
+  @Since("2.0.0")
+  def weights: Array[Double] = parentModel.weights
+
+  /** Gaussian components of the wrapped mllib model. */
+  @Since("2.0.0")
+  def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
+
+  /** Returns an [[MLWriter]] that saves this model. */
+  @Since("2.0.0")
+  override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
+
+  // Summary on the training data; None until setSummary is called.
+  private var trainingSummary: Option[GaussianMixtureSummary] = None
+
+  /** Attaches a training summary to this model and returns this model. */
+  private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Return true if there exists summary of model.
+   */
+  @Since("2.0.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: GaussianMixtureSummary = trainingSummary.getOrElse {
+    throw new RuntimeException(
+      s"No training summary available for the ${this.getClass.getSimpleName}")
+  }
+}
+
+@Since("2.0.0")
+object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
+
+  /** Returns an [[MLReader]] that restores a saved [[GaussianMixtureModel]]. */
+  @Since("2.0.0")
+  override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
+
+  /** Loads a [[GaussianMixtureModel]] from `path` via the inherited `load`. */
+  @Since("2.0.0")
+  override def load(path: String): GaussianMixtureModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[GaussianMixtureModel]] */
+  private[GaussianMixtureModel] class GaussianMixtureModelWriter(
+      instance: GaussianMixtureModel) extends MLWriter {
+
+    // Single-row layout of the persisted model data.
+    private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+
+    /** Writes metadata/params under `path` and the model data as parquet under `path/data`. */
+    override protected def saveImpl(path: String): Unit = {
+      // Metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Model data: mixture weights plus the mean and covariance of every gaussian
+      val mixtureWeights = instance.weights
+      val components = instance.gaussians
+      val record = Data(mixtureWeights, components.map(_.mu), components.map(_.sigma))
+      val target = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(record)).repartition(1).write.parquet(target)
+    }
+  }
+
+  private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[GaussianMixtureModel].getName
+
+    /** Rebuilds a [[GaussianMixtureModel]] from the metadata and parquet data under `path`. */
+    override def load(path: String): GaussianMixtureModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val source = new Path(path, "data").toString
+      val record = sqlContext.read.parquet(source).select("weights", "mus", "sigmas").head()
+      val weights = record.getSeq[Double](0).toArray
+      val means = record.getSeq[Vector](1).toArray
+      val covariances = record.getSeq[Matrix](2).toArray
+      require(means.length == covariances.length, "Length of Mu and Sigma array must match")
+      require(means.length == weights.length, "Length of weight and Gaussian array must match")
+
+      // Lengths were verified above, so index-wise pairing is equivalent to zipping.
+      val gaussians = Array.tabulate(means.length) { i =>
+        new MultivariateGaussian(means(i), covariances(i))
+      }
+      val restored = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+
+      DefaultParamsReader.getAndSetParams(restored, metadata)
+      restored
+    }
+  }
+}
+
+/**
+ * :: Experimental ::
+ * GaussianMixture clustering.
+ *
+ * Training is delegated to spark.mllib's GaussianMixture. Defaults: k = 2, maxIter = 100,
+ * tol = 0.01.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixture @Since("2.0.0") (
+    @Since("2.0.0") override val uid: String)
+  extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable {
+
+  /** Creates an estimator with a randomly generated uid. */
+  @Since("2.0.0")
+  def this() = this(Identifiable.randomUID("GaussianMixture"))
+
+  setDefault(
+    k -> 2,
+    maxIter -> 100,
+    tol -> 0.01)
+
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setK(value: Int): this.type = set(k, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setTol(value: Double): this.type = set(tol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  /**
+   * Trains a mixture model on the features column of `dataset` and returns it with a
+   * training summary attached.
+   */
+  @Since("2.0.0")
+  override def fit(dataset: DataFrame): GaussianMixtureModel = {
+    val points = dataset.select(col($(featuresCol))).rdd.map { case Row(v: Vector) => v }
+
+    // Configure the mllib trainer from this estimator's params.
+    val trainer = new MLlibGM()
+    trainer.setK($(k))
+    trainer.setMaxIterations($(maxIter))
+    trainer.setSeed($(seed))
+    trainer.setConvergenceTol($(tol))
+
+    val fitted = new GaussianMixtureModel(uid, trainer.run(points)).setParent(this)
+    val model = copyValues(fitted)
+    // The summary is computed from the model's own output on the training data.
+    model.setSummary(new GaussianMixtureSummary(
+      model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)))
+  }
+
+  /** Validates `schema` and returns the schema produced by the fitted model. */
+  @Since("2.0.0")
+  override def transformSchema(schema: StructType): StructType =
+    validateAndTransformSchema(schema)
+}
+
+@Since("2.0.0")
+object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
+
+  /** Loads a [[GaussianMixture]] estimator from `path` via the inherited `load`. */
+  @Since("2.0.0")
+  override def load(path: String): GaussianMixture = {
+    super.load(path)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Summary of GaussianMixture.
+ *
+ * @param predictions  [[DataFrame]] produced by [[GaussianMixtureModel.transform()]]
+ * @param predictionCol  Name for column of predicted clusters in `predictions`
+ * @param probabilityCol  Name for column of predicted probability of each cluster in `predictions`
+ * @param featuresCol  Name for column of features in `predictions`
+ * @param k  Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureSummary private[clustering] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String,
+    @Since("2.0.0") val probabilityCol: String,
+    @Since("2.0.0") val featuresCol: String,
+    @Since("2.0.0") val k: Int) extends Serializable {
+
+  /**
+   * Predicted cluster of each data point: the `predictionCol` column of `predictions`.
+   */
+  @Since("2.0.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * The `probabilityCol` column of `predictions`.
+   */
+  @Since("2.0.0")
+  @transient lazy val probability: DataFrame = predictions.select(probabilityCol)
+
+  /**
+   * Size of (number of data points in) each cluster. The array has length `k`; entries
+   * for clusters that received no points stay 0.
+   */
+  @Since("2.0.0")
+  lazy val clusterSizes: Array[Long] = {
+    val counts = Array.fill[Long](k)(0)
+    val perCluster = cluster.groupBy(predictionCol).count()
+      .select(predictionCol, "count")
+      .collect()
+    // Each row is (cluster index, number of points); write it into its slot.
+    perCluster.foreach {
+      case Row(clusterIdx: Int, size: Long) => counts(clusterIdx) = size
+    }
+    counts
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/af73d973/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
new file mode 100644
index 0000000..8edd44e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
+
+
+/** Tests for the ML-package GaussianMixture estimator, its model and its summary. */
+class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
+
+  final val k = 5
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    // Shared fixture, built with the generator from KMeansSuite.
+    dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+  }
+
+  test("default parameters") {
+    val estimator = new GaussianMixture()
+
+    assert(estimator.getK === 2)
+    assert(estimator.getFeaturesCol === "features")
+    assert(estimator.getPredictionCol === "prediction")
+    assert(estimator.getMaxIter === 100)
+    assert(estimator.getTol === 0.01)
+  }
+
+  test("set parameters") {
+    val estimator = new GaussianMixture()
+      .setK(9)
+      .setFeaturesCol("test_feature")
+      .setPredictionCol("test_prediction")
+      .setProbabilityCol("test_probability")
+      .setMaxIter(33)
+      .setSeed(123)
+      .setTol(1e-3)
+
+    assert(estimator.getK === 9)
+    assert(estimator.getFeaturesCol === "test_feature")
+    assert(estimator.getPredictionCol === "test_prediction")
+    assert(estimator.getProbabilityCol === "test_probability")
+    assert(estimator.getMaxIter === 33)
+    assert(estimator.getSeed === 123)
+    assert(estimator.getTol === 1e-3)
+  }
+
+  test("parameters validation") {
+    // k must be greater than 1.
+    intercept[IllegalArgumentException] {
+      new GaussianMixture().setK(1)
+    }
+  }
+
+  test("fit, transform, and summary") {
+    val predCol = "gm_prediction"
+    val probCol = "gm_probability"
+    val estimator = new GaussianMixture()
+      .setK(k)
+      .setMaxIter(2)
+      .setPredictionCol(predCol)
+      .setProbabilityCol(probCol)
+      .setSeed(1)
+    val model = estimator.fit(dataset)
+    assert(model.hasParent)
+    assert(model.weights.length === k)
+    assert(model.gaussians.length === k)
+
+    val output = model.transform(dataset)
+    for (expected <- Array("features", predCol, probCol)) {
+      assert(output.columns.contains(expected))
+    }
+
+    // Check validity of model summary
+    val rowCount = dataset.count()
+    assert(model.hasSummary)
+    val summary: GaussianMixtureSummary = model.summary
+    assert(summary.predictionCol === predCol)
+    assert(summary.probabilityCol === probCol)
+    assert(summary.featuresCol === "features")
+    assert(summary.predictions.count() === rowCount)
+    Array(predCol, probCol, "features").foreach { name =>
+      assert(summary.predictions.columns.contains(name))
+    }
+    assert(summary.cluster.columns === Array(predCol))
+    assert(summary.probability.columns === Array(probCol))
+    val sizes = summary.clusterSizes
+    assert(sizes.length === k)
+    assert(sizes.sum === rowCount)
+    assert(sizes.forall(_ >= 0))
+  }
+
+  test("read/write") {
+    // A reloaded model must carry the same weights, means and covariances.
+    def checkModelData(expected: GaussianMixtureModel, actual: GaussianMixtureModel): Unit = {
+      assert(expected.weights === actual.weights)
+      assert(expected.gaussians.map(_.mu) === actual.gaussians.map(_.mu))
+      assert(expected.gaussians.map(_.sigma) === actual.gaussians.map(_.sigma))
+    }
+    testEstimatorAndModelReadWrite(new GaussianMixture(), dataset,
+      GaussianMixtureSuite.allParamSettings, checkModelData)
+  }
+}
+
+object GaussianMixtureSuite {
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "probabilityCol" -> "myProbability",
+    "k" -> 3,
+    "maxIter" -> 2,
+    // GaussianMixture's default tol is 0.01, so use a different value here; otherwise
+    // the save/load round trip cannot tell a restored value from the default.
+    "tol" -> 0.001
+  )
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org