You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by db...@apache.org on 2016/04/27 01:53:19 UTC

spark git commit: [SPARK-14732][ML] spark.ml GaussianMixture should use MultivariateGaussian in mllib-local

Repository: spark
Updated Branches:
  refs/heads/master 0c99c23b7 -> bd2c9a6d4


[SPARK-14732][ML] spark.ml GaussianMixture should use MultivariateGaussian in mllib-local

## What changes were proposed in this pull request?

Previously, the spark.ml GaussianMixtureModel used the spark.mllib MultivariateGaussian in its public API. Because this API was added after 1.6, we can modify it without breaking any released API.

This PR copies MultivariateGaussian into mllib-local (package org.apache.spark.ml.stat.distribution), with the following change:
* Renamed fields to match NumPy/SciPy naming: mu => mean, sigma => cov

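This PR then uses the spark.ml MultivariateGaussian in the spark.ml GaussianMixtureModel, which involves:
* Modifying the constructor to take the mixture weights and Gaussians directly, instead of wrapping a spark.mllib model
* Adding a computeProbabilities method to the GaussianMixtureModel companion object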

Also:
* Added Utils.EPSILON (machine epsilon for Double) to mllib-local for use in MultivariateGaussian and GaussianMixtureModel

## How was this patch tested?

Existing unit tests, plus new suites for the mllib-local additions (UtilsSuite and MultivariateGaussianSuite).

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #12593 from jkbradley/sparkml-gmm-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd2c9a6d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd2c9a6d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd2c9a6d

Branch: refs/heads/master
Commit: bd2c9a6d48ef6d489c747d9db2642bdef6b1f728
Parents: 0c99c23
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue Apr 26 16:53:16 2016 -0700
Committer: DB Tsai <db...@netflix.com>
Committed: Tue Apr 26 16:53:16 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/impl/Utils.scala  |  30 +++++
 .../distribution/MultivariateGaussian.scala     | 131 +++++++++++++++++++
 .../org/apache/spark/ml/impl/UtilsSuite.scala   |  30 +++++
 .../MultivariateGaussianSuite.scala             |  83 ++++++++++++
 .../spark/ml/clustering/GaussianMixture.scala   | 108 ++++++++++-----
 .../ml/clustering/GaussianMixtureSuite.scala    |   4 +-
 python/pyspark/ml/clustering.py                 |  11 +-
 7 files changed, 353 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
new file mode 100644
index 0000000..112de98
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+
+private[ml] object Utils {
+
+  lazy val EPSILON = {
+    var eps = 1.0
+    while ((1.0 + (eps / 2.0)) != 1.0) {
+      eps /= 2.0
+    }
+    eps
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
new file mode 100644
index 0000000..c62a1ea
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import breeze.linalg.{diag, eigSym, max, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
+
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+
+
+/**
+ * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
+ * the event that the covariance matrix is singular, the density will be computed in a
+ * reduced-dimensional subspace on which the distribution is supported.
+ * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
+ *
+ * @param mean The mean vector of the distribution
+ * @param cov The covariance matrix of the distribution
+ */
+class MultivariateGaussian(
+    val mean: Vector,
+    val cov: Matrix) extends Serializable {
+
+  require(cov.numCols == cov.numRows, "Covariance matrix must be square")
+  require(mean.size == cov.numCols, "Mean vector length must match covariance matrix size")
+
+  /** Private constructor taking Breeze types */
+  private[ml] def this(mean: BDV[Double], cov: BDM[Double]) = {
+    this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov))
+  }
+
+  private val breezeMu = mean.toBreeze.toDenseVector
+
+  /**
+   * Compute distribution dependent constants:
+   *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
+   *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+   */
+  private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+
+  /**
+   * Returns density of this multivariate Gaussian at given point, x
+   */
+  def pdf(x: Vector): Double = {
+    pdf(x.toBreeze)
+  }
+
+  /**
+   * Returns the log-density of this multivariate Gaussian at given point, x
+   */
+  def logpdf(x: Vector): Double = {
+    logpdf(x.toBreeze)
+  }
+
+  /** Returns density of this multivariate Gaussian at given point, x */
+  private[ml] def pdf(x: BV[Double]): Double = {
+    math.exp(logpdf(x))
+  }
+
+  /** Returns the log-density of this multivariate Gaussian at given point, x */
+  private[ml] def logpdf(x: BV[Double]): Double = {
+    val delta = x - breezeMu
+    val v = rootSigmaInv * delta
+    u + v.t * v * -0.5
+  }
+
+  /**
+   * Calculate distribution dependent components used for the density function:
+   *    pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
+   * where k is length of the mean vector.
+   *
+   * We here compute distribution-fixed parts
+   *  log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+   * and
+   *  D^(-1/2)^ * U.t, where sigma = U * D * U.t
+   *
+   * Both the determinant and the inverse can be computed from the singular value decomposition
+   * of sigma.  Noting that covariance matrices are always symmetric and positive semi-definite,
+   * we can use the eigendecomposition. We also do not compute the inverse directly; noting
+   * that
+   *
+   *    sigma = U * D * U.t
+   *    inv(sigma) = U * inv(D) * U.t
+   *               = (D^(-1/2)^ * U.t).t * (D^(-1/2)^ * U.t)
+   *
+   * and thus
+   *
+   *    -0.5 * (x-mu).t * inv(sigma) * (x-mu) = -0.5 * norm(D^(-1/2)^ * U.t * (x-mu))^2^
+   *
+   * To guard against singular covariance matrices, this method computes both the
+   * pseudo-determinant and the pseudo-inverse (Moore-Penrose).  Singular values are considered
+   * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
+   * relation to the maximum singular value (same tolerance used by, e.g., Octave).
+   */
+  private def calculateCovarianceConstants: (BDM[Double], Double) = {
+    val eigSym.EigSym(d, u) = eigSym(cov.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
+
+    // For numerical stability, values are considered to be non-zero only if they exceed tol.
+    // This prevents any inverted value from exceeding (eps * n * max(d))^-1
+    val tol = Utils.EPSILON * max(d) * d.length
+
+    try {
+      // log(pseudo-determinant) is sum of the logs of all non-zero singular values
+      val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
+
+      // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+      // by inverting the square root of all non-zero values
+      val pinvS = diag(new BDV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
+
+      (pinvS * u.t, -0.5 * (mean.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
+    } catch {
+      case uex: UnsupportedOperationException =>
+        throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
new file mode 100644
index 0000000..44b122b
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.SparkMLFunSuite
+
+
+class UtilsSuite extends SparkMLFunSuite {
+
+  test("EPSILON") {
+    assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
+    assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
new file mode 100644
index 0000000..f9306ed
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import org.apache.spark.ml.SparkMLFunSuite
+import org.apache.spark.ml.linalg.{Matrices, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+
+
+class MultivariateGaussianSuite extends SparkMLFunSuite {
+
+  test("univariate") {
+    val x1 = Vectors.dense(0.0)
+    val x2 = Vectors.dense(1.5)
+
+    val mu = Vectors.dense(0.0)
+    val sigma1 = Matrices.dense(1, 1, Array(1.0))
+    val dist1 = new MultivariateGaussian(mu, sigma1)
+    assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
+    assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
+
+    val sigma2 = Matrices.dense(1, 1, Array(4.0))
+    val dist2 = new MultivariateGaussian(mu, sigma2)
+    assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
+    assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
+  }
+
+  test("multivariate") {
+    val x1 = Vectors.dense(0.0, 0.0)
+    val x2 = Vectors.dense(1.0, 1.0)
+
+    val mu = Vectors.dense(0.0, 0.0)
+    val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+    val dist1 = new MultivariateGaussian(mu, sigma1)
+    assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
+    assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
+
+    val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+    val dist2 = new MultivariateGaussian(mu, sigma2)
+    assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
+    assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
+  }
+
+  test("multivariate degenerate") {
+    val x1 = Vectors.dense(0.0, 0.0)
+    val x2 = Vectors.dense(1.0, 1.0)
+
+    val mu = Vectors.dense(0.0, 0.0)
+    val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+    val dist = new MultivariateGaussian(mu, sigma)
+    assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
+    assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
+  }
+
+  test("SPARK-11302") {
+    val x = Vectors.dense(629, 640, 1.7188, 618.19)
+    val mu = Vectors.dense(
+      1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
+    val sigma = Matrices.dense(4, 4, Array(
+      166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
+      169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
+      12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
+      164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
+    val dist = new MultivariateGaussian(mu, sigma)
+    // Agrees with R's dmvnorm: 7.154782e-05
+    assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index dfbc8b6..ac86e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -17,17 +17,21 @@
 
 package org.apache.spark.ml.clustering
 
+import breeze.linalg.{DenseVector => BDV}
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.distribution.MultivariateGaussian
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
-import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
+import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
+  Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -56,34 +60,42 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-    SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
   }
 }
 
 /**
  * :: Experimental ::
- * Model fitted by GaussianMixture.
- * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i with probability weights(i).
+ *
+ * @param weights Weight for each Gaussian distribution in the mixture.
+ *                This is a multinomial probability distribution over the k Gaussians,
+ *                where weights(i) is the weight for Gaussian i, and weights sum to 1.
+ * @param gaussians Array of [[MultivariateGaussian]] where gaussians(i) represents
+ *                  the Multivariate Gaussian (Normal) Distribution for Gaussian i
  */
 @Since("2.0.0")
 @Experimental
 class GaussianMixtureModel private[ml] (
     @Since("2.0.0") override val uid: String,
-    private val parentModel: MLlibGMModel)
+    @Since("2.0.0") val weights: Array[Double],
+    @Since("2.0.0") val gaussians: Array[MultivariateGaussian])
   extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): GaussianMixtureModel = {
-    val copied = new GaussianMixtureModel(uid, parentModel)
+    val copied = new GaussianMixtureModel(uid, weights, gaussians)
     copyValues(copied, extra).setParent(this.parent)
   }
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    val predUDF = udf((vector: Vector) => predict(vector))
-    val probUDF = udf((vector: Vector) => predictProbability(vector))
+    val predUDF = udf((vector: OldVector) => predict(vector.asML))
+    val probUDF = udf((vector: OldVector) => OldVectors.fromML(predictProbability(vector.asML)))
     dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
       .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
   }
@@ -93,33 +105,32 @@ class GaussianMixtureModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+  private[clustering] def predict(features: Vector): Int = {
+    val r = predictProbability(features)
+    r.argmax
+  }
 
   private[clustering] def predictProbability(features: Vector): Vector = {
-    Vectors.dense(parentModel.predictSoft(features))
+    val probs: Array[Double] =
+      GaussianMixtureModel.computeProbabilities(features.toBreeze.toDenseVector, gaussians, weights)
+    Vectors.dense(probs)
   }
 
-  @Since("2.0.0")
-  def weights: Array[Double] = parentModel.weights
-
-  @Since("2.0.0")
-  def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
-
   /**
    * Retrieve Gaussian distributions as a DataFrame.
    * Each row represents a Gaussian Distribution.
    * Two columns are defined: mean and cov.
    * Schema:
    * {{{
-   * root
-   * |-- mean: vector (nullable = true)
-   * |-- cov: matrix (nullable = true)
+   *  root
+   *   |-- mean: vector (nullable = true)
+   *   |-- cov: matrix (nullable = true)
    * }}}
    */
   @Since("2.0.0")
   def gaussiansDF: DataFrame = {
     val modelGaussians = gaussians.map { gaussian =>
-      (gaussian.mu, gaussian.sigma)
+      (OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov))
     }
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
@@ -166,7 +177,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
   private[GaussianMixtureModel] class GaussianMixtureModelWriter(
       instance: GaussianMixtureModel) extends MLWriter {
 
-    private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+    private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
@@ -174,8 +185,8 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       // Save model data: weights and gaussians
       val weights = instance.weights
       val gaussians = instance.gaussians
-      val mus = gaussians.map(_.mu)
-      val sigmas = gaussians.map(_.sigma)
+      val mus = gaussians.map(g => OldVectors.fromML(g.mean))
+      val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
       val data = Data(weights, mus, sigmas)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
@@ -193,26 +204,50 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       val dataPath = new Path(path, "data").toString
       val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
       val weights = row.getSeq[Double](0).toArray
-      val mus = row.getSeq[Vector](1).toArray
-      val sigmas = row.getSeq[Matrix](2).toArray
+      val mus = row.getSeq[OldVector](1).toArray
+      val sigmas = row.getSeq[OldMatrix](2).toArray
       require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
       require(mus.length == weights.length, "Length of weight and Gaussian array must match")
 
-      val gaussians = (mus zip sigmas).map {
+      val gaussians = mus.zip(sigmas).map {
         case (mu, sigma) =>
-          new MultivariateGaussian(mu, sigma)
+          new MultivariateGaussian(mu.asML, sigma.asML)
       }
-      val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+      val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
 
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
   }
+
+  /**
+   * Compute the probability (partial assignment) for each cluster for the given data point.
+   * @param features  Data point
+   * @param dists  Gaussians for model
+   * @param weights  Weights for each Gaussian
+   * @return  Probability (partial assignment) for each of the k clusters
+   */
+  private[clustering]
+  def computeProbabilities(
+      features: BDV[Double],
+      dists: Array[MultivariateGaussian],
+      weights: Array[Double]): Array[Double] = {
+    val p = weights.zip(dists).map {
+      case (weight, dist) => EPSILON + weight * dist.pdf(features)
+    }
+    val pSum = p.sum
+    var i = 0
+    while (i < weights.length) {
+      p(i) /= pSum
+      i += 1
+    }
+    p
+  }
 }
 
 /**
  * :: Experimental ::
- * GaussianMixture clustering.
+ * Gaussian Mixture clustering.
  */
 @Since("2.0.0")
 @Experimental
@@ -261,7 +296,7 @@ class GaussianMixture @Since("2.0.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
-    val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
+    val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: OldVector) => point }
 
     val algo = new MLlibGM()
       .setK($(k))
@@ -269,8 +304,11 @@ class GaussianMixture @Since("2.0.0") (
       .setSeed($(seed))
       .setConvergenceTol($(tol))
     val parentModel = algo.run(rdd)
-    val model = copyValues(new GaussianMixtureModel(uid, parentModel)
-      .setParent(this))
+    val gaussians = parentModel.gaussians.map { case g =>
+      new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+    }
+    val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
+      .setParent(this)
     val summary = new GaussianMixtureSummary(model.transform(dataset),
       $(predictionCol), $(probabilityCol), $(featuresCol), $(k))
     model.setSummary(summary)

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index df6bb41..9d86817 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -108,8 +108,8 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
   test("read/write") {
     def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = {
       assert(model.weights === model2.weights)
-      assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu))
-      assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma))
+      assert(model.gaussians.map(_.mean) === model2.gaussians.map(_.mean))
+      assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
     testEstimatorAndModelReadWrite(gm, dataset,

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2c9a6d/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 9740ec4..16ce02e 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -39,8 +39,9 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
     @since("2.0.0")
     def weights(self):
         """
-        Weights for each Gaussian distribution in the mixture, where weights[i] is
-        the weight for Gaussian i, and weights.sum == 1.
+        Weight for each Gaussian distribution in the mixture.
+        This is a multinomial probability distribution over the k Gaussians,
+        where weights[i] is the weight for Gaussian i, and weights sum to 1.
         """
         return self._call_java("weights")
 
@@ -50,11 +51,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
         """
         Retrieve Gaussian distributions as a DataFrame.
         Each row represents a Gaussian Distribution.
-        Two columns are defined: mean and cov.
-        Schema:
-        root
-        -- mean: vector (nullable = true)
-        -- cov: matrix (nullable = true)
+        The DataFrame has two columns: mean (Vector) and cov (Matrix).
         """
         return self._call_java("gaussiansDF")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org