You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/01/12 06:31:33 UTC
spark git commit: SPARK-5018 [MLlib] [WIP] Make MultivariateGaussian
public
Repository: spark
Updated Branches:
refs/heads/master f38ef6586 -> 2130de9d8
SPARK-5018 [MLlib] [WIP] Make MultivariateGaussian public
Moving MultivariateGaussian from private[mllib] to public. The class uses Breeze vectors and matrices internally, so this involves creating a public interface that uses MLlib vectors and matrices.
This initial commit provides a public constructor, accessors for the mean and covariance, and density and log-density functions.
Other potential methods include entropy computation and sample generation.
Author: Travis Galoppo <tj...@columbia.edu>
Closes #3923 from tgaloppo/spark-5018 and squashes the following commits:
2b15587 [Travis Galoppo] Style correction
b4121b4 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' into spark-5018
e30a100 [Travis Galoppo] Made mu, sigma private[mllib] members of MultivariateGaussian Moved MultivariateGaussian (and test suite) from stat.impl to stat.distribution (required updates in GaussianMixture{EM,Model}.scala) Marked MultivariateGaussian as @DeveloperApi Fixed style error
9fa3bb7 [Travis Galoppo] Style improvements
91a5fae [Travis Galoppo] Rearranged equation for part of density function
8c35381 [Travis Galoppo] Fixed accessor methods to match member variable names. Modified calculations to avoid log(pow(x,y)) calculations
0943dc4 [Travis Galoppo] SPARK-5018
4dee9e1 [Travis Galoppo] SPARK-5018
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2130de9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2130de9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2130de9d
Branch: refs/heads/master
Commit: 2130de9d8f50f52b9b2d296b377df81d840546b3
Parents: f38ef65
Author: Travis Galoppo <tj...@columbia.edu>
Authored: Sun Jan 11 21:31:16 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun Jan 11 21:31:16 2015 -0800
----------------------------------------------------------------------
.../mllib/clustering/GaussianMixtureEM.scala | 11 +-
.../mllib/clustering/GaussianMixtureModel.scala | 2 +-
.../distribution/MultivariateGaussian.scala | 134 +++++++++++++++++++
.../mllib/stat/impl/MultivariateGaussian.scala | 100 --------------
.../MultivariateGaussianSuite.scala | 69 ++++++++++
.../stat/impl/MultivariateGaussianSuite.scala | 70 ----------
6 files changed, 210 insertions(+), 176 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index b3c5631..d8e1346 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -20,10 +20,11 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.IndexedSeq
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
-import org.apache.spark.rdd.RDD
+
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
-import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
/**
@@ -134,7 +135,7 @@ class GaussianMixtureEM private (
// derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
- new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+ new MultivariateGaussian(mu, sigma)
})
case None => {
@@ -176,8 +177,8 @@ class GaussianMixtureEM private (
}
// Need to convert the breeze matrices to MLlib matrices
- val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
- val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
+ val means = Array.tabulate(k) { i => gaussians(i).mu }
+ val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
new GaussianMixtureModel(weights, means, sigmas)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index b461ea4..416cad0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrix, Vector}
-import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
new file mode 100644
index 0000000..fd186b5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.distribution
+
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * :: DeveloperApi ::
+ * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
+ * the event that the covariance matrix is singular, the density will be computed in a
+ * reduced dimensional subspace under which the distribution is supported.
+ * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
+ *
+ * @param mu The mean vector of the distribution
+ * @param sigma The covariance matrix of the distribution
+ */
+@DeveloperApi
+class MultivariateGaussian (
+ val mu: Vector,
+ val sigma: Matrix) extends Serializable {
+
+ require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
+ require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
+
+ private val breezeMu = mu.toBreeze.toDenseVector
+
+ /**
+ * private[mllib] constructor
+ *
+ * @param mu The mean vector of the distribution
+ * @param sigma The covariance matrix of the distribution
+ */
+ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
+ this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
+ }
+
+ /**
+ * Compute distribution dependent constants:
+ * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ */
+ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
+
+ /** Returns density of this multivariate Gaussian at given point, x */
+ def pdf(x: Vector): Double = {
+ pdf(x.toBreeze.toDenseVector)
+ }
+
+ /** Returns the log-density of this multivariate Gaussian at given point, x */
+ def logpdf(x: Vector): Double = {
+ logpdf(x.toBreeze.toDenseVector)
+ }
+
+ /** Returns density of this multivariate Gaussian at given point, x */
+ private[mllib] def pdf(x: DBV[Double]): Double = {
+ math.exp(logpdf(x))
+ }
+
+ /** Returns the log-density of this multivariate Gaussian at given point, x */
+ private[mllib] def logpdf(x: DBV[Double]): Double = {
+ val delta = x - breezeMu
+ val v = rootSigmaInv * delta
+ u + v.t * v * -0.5
+ }
+
+ /**
+ * Calculate distribution dependent components used for the density function:
+ * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
+ * where k is length of the mean vector.
+ *
+ * We here compute distribution-fixed parts
+ * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ * and
+ * D^(-1/2)^ * U, where sigma = U * D * U.t
+ *
+ * Both the determinant and the inverse can be computed from the singular value decomposition
+ * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
+ * we can use the eigendecomposition. We also do not compute the inverse directly; noting
+ * that
+ *
+ * sigma = U * D * U.t
+ * inv(Sigma) = U * inv(D) * U.t
+ * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
+ *
+ * and thus
+ *
+ * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
+ *
+ * To guard against singular covariance matrices, this method computes both the
+ * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
+ * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
+ * relation to the maximum singular value (same tolerance used by, e.g., Octave).
+ */
+ private def calculateCovarianceConstants: (DBM[Double], Double) = {
+ val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
+
+ // For numerical stability, values are considered to be non-zero only if they exceed tol.
+ // This prevents any inverted value from exceeding (eps * n * max(d))^-1
+ val tol = MLUtils.EPSILON * max(d) * d.length
+
+ try {
+ // log(pseudo-determinant) is sum of the logs of all non-zero singular values
+ val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
+
+ // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+ // by inverting the square root of all non-zero values
+ val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
+
+ (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
+ } catch {
+ case uex: UnsupportedOperationException =>
+ throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
deleted file mode 100644
index bc7f6c5..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.stat.impl
-
-import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
-
-import org.apache.spark.mllib.util.MLUtils
-
-/**
- * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
- * the event that the covariance matrix is singular, the density will be computed in a
- * reduced dimensional subspace under which the distribution is supported.
- * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
- *
- * @param mu The mean vector of the distribution
- * @param sigma The covariance matrix of the distribution
- */
-private[mllib] class MultivariateGaussian(
- val mu: DBV[Double],
- val sigma: DBM[Double]) extends Serializable {
-
- /**
- * Compute distribution dependent constants:
- * rootSigmaInv = D^(-1/2) * U, where sigma = U * D * U.t
- * u = (2*pi)^(-k/2) * det(sigma)^(-1/2)
- */
- private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
-
- /** Returns density of this multivariate Gaussian at given point, x */
- def pdf(x: DBV[Double]): Double = {
- val delta = x - mu
- val v = rootSigmaInv * delta
- u * math.exp(v.t * v * -0.5)
- }
-
- /**
- * Calculate distribution dependent components used for the density function:
- * pdf(x) = (2*pi)^(-k/2) * det(sigma)^(-1/2) * exp( (-1/2) * (x-mu).t * inv(sigma) * (x-mu) )
- * where k is length of the mean vector.
- *
- * We here compute distribution-fixed parts
- * (2*pi)^(-k/2) * det(sigma)^(-1/2)
- * and
- * D^(-1/2) * U, where sigma = U * D * U.t
- *
- * Both the determinant and the inverse can be computed from the singular value decomposition
- * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
- * we can use the eigendecomposition. We also do not compute the inverse directly; noting
- * that
- *
- * sigma = U * D * U.t
- * inv(Sigma) = U * inv(D) * U.t
- * = (D^{-1/2} * U).t * (D^{-1/2} * U)
- *
- * and thus
- *
- * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2} * U * (x-mu))^2
- *
- * To guard against singular covariance matrices, this method computes both the
- * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
- * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
- * relation to the maximum singular value (same tolerance used by, e.g., Octave).
- */
- private def calculateCovarianceConstants: (DBM[Double], Double) = {
- val eigSym.EigSym(d, u) = eigSym(sigma) // sigma = u * diag(d) * u.t
-
- // For numerical stability, values are considered to be non-zero only if they exceed tol.
- // This prevents any inverted value from exceeding (eps * n * max(d))^-1
- val tol = MLUtils.EPSILON * max(d) * d.length
-
- try {
- // pseudo-determinant is product of all non-zero singular values
- val pdetSigma = d.activeValuesIterator.filter(_ > tol).reduce(_ * _)
-
- // calculate the root-pseudo-inverse of the diagonal matrix of singular values
- // by inverting the square root of all non-zero values
- val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
-
- (pinvS * u, math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(pdetSigma, -0.5))
- } catch {
- case uex: UnsupportedOperationException =>
- throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
new file mode 100644
index 0000000..fac2498
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.distribution
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
+ test("univariate") {
+ val x1 = Vectors.dense(0.0)
+ val x2 = Vectors.dense(1.5)
+
+ val mu = Vectors.dense(0.0)
+ val sigma1 = Matrices.dense(1, 1, Array(1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(1, 1, Array(4.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
+ }
+
+ test("multivariate") {
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
+ }
+
+ test("multivariate degenerate") {
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+ val dist = new MultivariateGaussian(mu, sigma)
+ assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
+ assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/2130de9d/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
deleted file mode 100644
index d58f258..0000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.stat.impl
-
-import org.scalatest.FunSuite
-
-import breeze.linalg.{ DenseVector => BDV, DenseMatrix => BDM }
-
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-
-class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
- test("univariate") {
- val x1 = new BDV(Array(0.0))
- val x2 = new BDV(Array(1.5))
-
- val mu = new BDV(Array(0.0))
- val sigma1 = new BDM(1, 1, Array(1.0))
- val dist1 = new MultivariateGaussian(mu, sigma1)
- assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
- assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
-
- val sigma2 = new BDM(1, 1, Array(4.0))
- val dist2 = new MultivariateGaussian(mu, sigma2)
- assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
- assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
- }
-
- test("multivariate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
-
- val mu = new BDV(Array(0.0, 0.0))
- val sigma1 = new BDM(2, 2, Array(1.0, 0.0, 0.0, 1.0))
- val dist1 = new MultivariateGaussian(mu, sigma1)
- assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
- assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
-
- val sigma2 = new BDM(2, 2, Array(4.0, -1.0, -1.0, 2.0))
- val dist2 = new MultivariateGaussian(mu, sigma2)
- assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
- assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
- }
-
- test("multivariate degenerate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
-
- val mu = new BDV(Array(0.0, 0.0))
- val sigma = new BDM(2, 2, Array(1.0, 1.0, 1.0, 1.0))
- val dist = new MultivariateGaussian(mu, sigma)
- assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
- assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org