Posted to commits@systemds.apache.org by ss...@apache.org on 2020/08/03 16:31:48 UTC

[systemds] branch master updated: [SYSTEMDS-200] Gaussian Mixture Model, a new builtin for unsupervised learning. This commit also contains a minor change in LibCommonsMath.java, the RELATIVE_SYMMETRY_THRESHOLD value for Cholesky decomposition is updated from 1e-14 to 1e-10. Date: Mon Aug 3 18:30:03 2020 +0200

This is an automated email from the ASF dual-hosted git repository.

ssiddiqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5591a6e  [SYSTEMDS-200] Gaussian Mixture Model, a new builtin for unsupervised learning. This commit also contains a minor change in LibCommonsMath.java, the RELATIVE_SYMMETRY_THRESHOLD value for Cholesky decomposition is updated from 1e-14 to 1e-10. Date:      Mon Aug 3 18:30:03 2020 +0200
5591a6e is described below

commit 5591a6e0d493e04f5f0789e4cbad25f4bbd92323
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Mon Aug 3 18:30:03 2020 +0200

    [SYSTEMDS-200] Gaussian Mixture Model, a new builtin for unsupervised learning.
    This commit also contains a minor change in LibCommonsMath.java: the relative symmetry threshold for the Cholesky decomposition, now held in the constant RELATIVE_SYMMETRY_THRESHOLD, is raised from 1e-14 to 1e-10.
    Date:      Mon Aug 3 18:30:03 2020 +0200
---
 dev/Tasks-obsolete.txt                             |   1 +
 docs/site/builtins-reference.md                    |  39 ++
 scripts/builtin/gmm.dml                            | 418 +++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../sysds/runtime/matrix/data/LibCommonsMath.java  |   6 +-
 .../test/functions/builtin/BuiltinGMMTest.java     | 156 ++++++++
 src/test/scripts/functions/builtin/GMM.R           |  34 ++
 src/test/scripts/functions/builtin/GMM.dml         |  32 ++
 .../functions/transform/input/iris/iris.csv        |   2 +-
 9 files changed, 686 insertions(+), 3 deletions(-)

diff --git a/dev/Tasks-obsolete.txt b/dev/Tasks-obsolete.txt
index 17dd708..4ef8729 100644
--- a/dev/Tasks-obsolete.txt
+++ b/dev/Tasks-obsolete.txt
@@ -180,6 +180,7 @@ SYSTEMDS-190 New Builtin Functions III
  * 197 Builtin function for functional dependency discovery           OK
  * 198 Extended slice finding (classification)                        OK
  * 199 Builtin function Multinominal Logistic Regression Predict      OK
+ * 200 Builtin function Gaussian Mixture Model                        OK
 
 SYSTEMDS-200 Various Fixes
  * 201 Fix spark append instruction zero columns                      OK
diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index 6931d42..d86a4ad 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -57,6 +57,7 @@ limitations under the License.
     * [`outlier`-Function](#outlier-function)
     * [`toOneHot`-Function](#toOneHOt-function)
     * [`winsorize`-Function](#winsorize-function)
+    * [`gmm`-Function](#gmm-function)
     
     
 # Introduction
@@ -1093,3 +1094,41 @@ winsorize(X)
 X = rand(rows=10, cols=10,min = 1, max=9)
 Y = winsorize(X=X)
 ```
+
+## `gmm`-Function
+
+The `gmm`-function implements a builtin Gaussian Mixture Model with four different types of
+covariance matrices, i.e., VVV, EEE, VVI, and VII, and two initialization methods, namely "kmeans" and "random".
+
+### Usage
+```r
+gmm(X=X, n_components = 3,  model = "VVV",  init_params = "random", iter = 100, reg_covar = 0.000001, tol = 0.0001, verbose=TRUE)
+```
+
+
+### Arguments
+| Name          | Type             | Default    | Description |
+| :------       | :-------------   | --------   | :---------- |
+| X             | Double           | ---        | Matrix X of feature vectors.|
+| n_components  | Integer          | 1          | Number of components in the Gaussian mixture model |
+| model         | String           | "VVV"      | "VVV": unequal variance (full), each component has its own general covariance matrix<br><br>"EEE": equal variance (tied), all components share the same general covariance matrix<br><br>"VVI": diagonal, unequal volume and shape (diag), each component has its own diagonal covariance matrix<br><br>"VII": spherical, unequal volume (spherical), each component has its own single variance |
+| init_params   | String           | "kmeans"   | Initialization of the responsibilities, "kmeans" or "random"|
+| iter          | Integer          | 100        | Maximum number of EM iterations|
+| reg_covar     | Double           | 1e-6       | Regularization added to the diagonal of each covariance matrix|
+| tol           | Double           | 0.000001   | Convergence tolerance on the change of the mean log-likelihood between iterations|
+| verbose       | Boolean          | FALSE      | Set to TRUE to print intermediate results.|
+
+
+### Returns
+| Name    | Type           | Default  | Description |
+| :------ | :------------- | -------- | :---------- |
+| weights | Double        | ---      | A matrix whose [i,k]th entry is the probability that observation i in the input data belongs to the kth component|
+| labels  | Double        | ---      | Column vector with the predicted component index (1-based) of each observation|
+| df      | Integer       | ---      | Number of estimated parameters|
+| bic     | Double        | ---      | Bayesian information criterion of the fitted model|
+
+### Example
+```r
+X = read($1)
+[weights, labels, df, bic] = gmm(X=X, n_components = 3, model = "VVV", init_params = "random", iter = 100, reg_covar = 0.000001, tol = 0.0001, verbose=TRUE)
+```
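The `df` and `bic` outputs documented above follow from a closed-form parameter count. The Java sketch below re-derives the arithmetic of `estimate_free_param` and `getBIC` from `scripts/builtin/gmm.dml`. The class and method names are hypothetical. `meanLogLik` stands for the mean per-observation log-likelihood that the script passes to `getBIC` as `norm`.

```java
// Hypothetical helper, not part of SystemDS: it re-derives the arithmetic of
// estimate_free_param and getBIC from scripts/builtin/gmm.dml.
final class GmmModelSelection {

    private GmmModelSelection() {}

    /** Number of free parameters for k components in d dimensions. */
    static int freeParameters(String model, int k, int d) {
        int covParams;
        switch (model) {
            case "VVV": covParams = k * d * (d + 1) / 2; break; // one full covariance per component
            case "EEE": covParams = d * (d + 1) / 2; break;     // one full covariance shared by all
            case "VVI": covParams = k * d; break;               // one diagonal per component
            case "VII": covParams = k; break;                   // one scalar variance per component
            default: throw new IllegalArgumentException("invalid model: " + model);
        }
        int meanParams = k * d;    // one mean vector per component
        int weightParams = k - 1;  // mixing weights are constrained to sum to one
        return covParams + meanParams + weightParams;
    }

    /** BIC = -2 * n * meanLogLik + df * ln(n), where meanLogLik is the per-observation average. */
    static double bic(int n, double meanLogLik, int df) {
        return -2.0 * meanLogLik * n + df * Math.log(n);
    }

    public static void main(String[] args) {
        // Iris-sized setting: d = 4 features, k = 3 components; prints 44, 24, 26, 17.
        for (String model : new String[] {"VVV", "EEE", "VVI", "VII"}) {
            System.out.println(model + " df = " + freeParameters(model, 3, 4));
        }
    }
}
```

Since BIC adds a `df * ln(n)` penalty to `-2` times the log-likelihood, a lower value is the conventional indication of a better fit-versus-size trade-off. Comparing `bic` across `model` and `n_components` settings is therefore one way to choose between them.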
diff --git a/scripts/builtin/gmm.dml b/scripts/builtin/gmm.dml
new file mode 100644
index 0000000..2eafe13
--- /dev/null
+++ b/scripts/builtin/gmm.dml
@@ -0,0 +1,418 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# ------------------------------------------
+# Gaussian Mixture Model
+# ------------------------------------------
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# X               Double   ---       Matrix X of feature vectors
+# n_components    Integer  1         Number of components in the Gaussian mixture model
+# model           String   "VVV"     "VVV": unequal variance (full), each component has its own general covariance matrix
+#                                    "EEE": equal variance (tied), all components share the same general covariance matrix
+#                                    "VVI": diagonal, unequal volume and shape (diag), each component has its own diagonal covariance matrix
+#                                    "VII": spherical, unequal volume (spherical), each component has its own single variance
+# init_params     String   "kmeans"  Initialization of the responsibilities, "kmeans" or "random"
+# iter            Integer  100       Maximum number of EM iterations
+# reg_covar       Double   1e-6      Regularization added to the diagonal of each covariance matrix
+# tol             Double   0.000001  Convergence tolerance on the change of the mean log-likelihood
+# verbose         Boolean  FALSE     Set to TRUE to print intermediate results
+# ---------------------------------------------------------------------------------------------
+
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME            TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# weights         Double   ---      A matrix whose [i,k]th entry is the probability that observation i in the input data belongs to the kth component
+# labels          Double   ---      Predicted component index (1-based) of each observation
+# df              Integer  ---      Number of estimated parameters
+# bic             Double   ---      Bayesian information criterion of the fitted model
+
+
+
+
+m_gmm = function(Matrix[Double] X, Integer n_components = 1, String model = "VVV", String init_params = "kmeans", Integer iter = 100, Double reg_covar = 1e-6, Double tol = 0.000001, Boolean verbose = FALSE )
+return (Matrix[Double] weights, Matrix[Double] labels, Integer df, Double bic)
+{ 
+
+  # sanity checks
+  if(model != "VVV" & model != "EEE" & model != "VVI" & model != "VII")
+    stop("model not supported, should be in VVV, EEE, VVI, VII");
+
+  [labels, weights, norm] = fit(X, n_components, model, init_params, iter, reg_covar, tol)
+  df = estimate_free_param(n_components, ncol(X), model)
+  bic = getBIC(nrow(X),norm,df)
+}
+ 
+initialize_param = function(Matrix[Double] X, Integer n_components, String init_params, String model, Double reg_covar, Double tol)
+return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol) 
+{
+  # create responsibility matrix, resp[n_samples, n_components]
+  resp = matrix(0, nrow(X), n_components)
+  if(init_params == "kmeans")
+  {
+    [C, Y] = kmeans(X, n_components, 10, 10, tol, FALSE, 25)
+    resp = resp + t(seq(1,n_components))
+    resp = resp == Y
+  }
+  else if(init_params == "random")
+  {
+    resp = Rand(rows = nrow(X), cols=n_components)   
+    resp = resp/rowSums(resp)
+  }
+  else stop("invalid parameter value, expected kmeans or random found "+init_params) 
+  
+  [weight, mean, sigma, precision_chol] = initialize(X, resp, n_components, model, reg_covar)
+}
+
+
+# Matrix/Vector Parameters
+# input: (X[n_samples, n_features], resp[n_samples, n_components])
+# output: (weight[1, n_components], mean[n_components, n_features], sigma/prec_chol depend on the model type)
+initialize = function(Matrix[Double] X, Matrix[Double] resp, Integer n_components, String model,  Double reg_covar)
+return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol)  
+{
+  n =  nrow(X)
+  [weight, mean, sigma] = estimate_gaussian_param(X, resp, n_components, model, reg_covar)
+  weight = weight/n
+  precision_chol = compute_precision_cholesky(sigma, model)
+}
+
+estimate_gaussian_param = function(Matrix[Double] X, Matrix[Double] resp, Integer n_components, String model, Double reg_covar)
+return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma)
+{
+  l = list()
+  n =  nrow(X)
+  # estimate Gaussian parameter
+  nk = colSums(resp) + 2.220446049250313e-15
+  mu = (t(resp) %*% X) / t(nk)
+  sigma = list()
+  if(model == "VVV")
+    sigma = covariances_VVV(X, resp, mu, nk, reg_covar) 
+  else if(model == "EEE")
+    sigma = covariances_EEE(X, resp, mu, nk, reg_covar)
+  else if(model ==  "VVI")
+    sigma = covariances_VVI(X, resp, mu, nk, reg_covar)
+  else if (model == "VII")
+    sigma = covariances_VII(X, resp, mu, nk, reg_covar)
+    
+  weight = nk
+  mean = mu # n_components * n_features     
+}
+
+# Matrix/Vector Parameters/List
+# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
+# output: (sigma a list of length = n_components where each item in list is a covariance matrix of (n_features * n_features) dimensions)
+covariances_VVV = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
+return(List[Unknown] sigma)
+{
+  sigma = list()
+  for(k in 1:nrow(mu))
+  {
+    diff = X - mu[k,]
+    cov = (t(diff * resp[, k]) %*% diff) / as.scalar(nk[1,k])
+    cov = cov + diag(matrix(reg_covar, ncol(cov), 1))
+    sigma = append(sigma, cov)
+  }       
+}
+
+# Matrix/Vector Parameters/List
+# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
+# output: (sigma a list of length = 1 where  item in list is a covariance matrix of (n_features * n_features) dimensions)
+covariances_EEE = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
+return(List[Unknown] sigma)
+{
+  sigma = list()
+  avgX2 = t(X) %*% X
+  avgMean = (t(mu) * nk) %*% mu
+  cov = avgX2 - avgMean
+  cov = cov / sum(nk)
+  cov = cov + diag(matrix(reg_covar, ncol(cov), 1))
+  sigma = append(sigma, cov)
+}
+
+# Matrix/Vector Parameters/List
+# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
+# output: (sigma a list of length = 1 whose item is a matrix of per-component diagonal variances of (n_components * n_features) dimensions)
+covariances_VVI = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
+return(List[Unknown] sigma)
+{
+  sigma = list()
+  avgX2 = (t(resp) %*% (X*X)) / t(nk)
+  avgMean = mu ^ 2
+  avgMean2 = mu * (t(resp) %*% X) / t(nk)
+  cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
+  sigma = append(sigma, cov)
+}
+
+# Matrix/Vector Parameters/List
+# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
+# output: (sigma a list of length = 1 whose item holds a single variance value for each component, (n_components * 1) dimensions)
+covariances_VII = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
+return(List[Unknown] sigma)
+{
+  sigma = list()
+  avgX2 = (t(resp) %*% (X*X)) / t(nk)
+  avgMean = mu ^ 2
+  avgMean2 = mu * (t(resp) %*% X) / t(nk)
+  cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
+  sigma = list(rowMeans(cov))
+}
+
+compute_precision_cholesky = function(List[Unknown] sigma, String model)
+return (List[Unknown] precision_chol)
+{
+  precision_chol = list()
+
+  if(model == "VVV") {
+    comp = length(sigma)
+    for(k in 1:length(sigma)) {
+      cov = as.matrix(sigma[k]) 
+      isSPD = checkSPD(cov)
+      if(isSPD) {
+        cov_chol = cholesky(cov)
+        pre_chol = t(inv(cov_chol))                      
+        precision_chol = append(precision_chol, pre_chol)
+      } else 
+        stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singular matrix or non-symmetric)."+
+        "\nTry to decrease the number of components, or increase reg_covar")
+    }      
+  }
+  else if(model == "EEE") {
+    cov = as.matrix(sigma[1])
+    isSPD = checkSPD(cov)
+    if(isSPD) {
+      cov_chol = cholesky(cov)
+      pre_chol = t(inv(cov_chol))
+      precision_chol = append(precision_chol, pre_chol)
+    } else 
+      stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singular matrix or non-symmetric)."+
+      "\nTry to decrease the number of components, or increase reg_covar")
+  }
+  else {
+    cov = as.matrix(sigma[1])
+    if(sum(cov <= 0) > 0)
+      stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singular matrix or non-symmetric)."+
+      "\nTry to decrease the number of components, or increase reg_covar")
+    else {
+      precision_chol = append(precision_chol, 1.0/sqrt(cov))
+    }   
+  }
+}
+
+# Expectation step
+e_step = function(Matrix[Double] X, Matrix[Double] w, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
+return(Double norm, Matrix[Double] log_resp){
+  weighted_log_prob = estimate_weighted_log_prob(X, w, mu, precisions_cholesky, model)
+  log_prob_norm = logsumexp(weighted_log_prob)
+  log_resp = weighted_log_prob - log_prob_norm
+  norm = mean(log_prob_norm)
+}
+
+# maximization Step
+m_step = function(Matrix[Double] X, Matrix[Double] log_resp, Integer n_components, String model, Double reg_covar)
+return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol) {
+  n =  nrow(X)
+  [weight, mean, sigma] = estimate_gaussian_param(X, exp(log_resp), n_components, model, reg_covar)
+  weight = weight/n
+  precision_chol = compute_precision_cholesky(sigma, model)
+}
+
+estimate_weighted_log_prob = function(Matrix[Double] X, Matrix[Double] w, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
+return (Matrix[Double] weight_log_pro)
+{
+  weight_log_pro = estimate_log_prob(X, mu, precisions_cholesky, model) + estimate_log_weights(w)
+}
+  
+estimate_log_weights = function(Matrix[Double] w)
+return (Matrix[Double] log_weight)
+{
+  log_weight = log(w)
+}
+
+estimate_log_prob = function(Matrix[Double] X, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
+return (Matrix[Double] log_prob)
+{
+  log_prob = estimate_log_gaussian_prob(
+            X, mu, precisions_cholesky, model)
+}
+
+compute_log_det_cholesky = function(List[Unknown] mat_chol, String model, Integer d)
+return(Matrix[Double] log_det_cholesky)
+{
+  comp = length(mat_chol)
+
+  if(model == "VVV")
+  { 
+    log_det_chol = matrix(0, 1, comp) 
+    for(k in 1:comp)
+    {
+      mat = as.matrix(mat_chol[k])
+      log_det = sum(log(diag(t(mat))))   # have to take the log of diag elements only
+      log_det_chol[1,k] = log_det
+    }      
+  }
+  else if(model == "EEE")
+  { 
+    mat = as.matrix(mat_chol[1])
+    log_det_chol = as.matrix(sum(log(diag(mat))))
+  }
+  else if(model ==  "VVI")
+  {
+    mat = as.matrix(mat_chol[1])
+    log_det_chol = t(rowSums(log(mat)))
+  }
+  else if (model == "VII")
+  {
+    mat = as.matrix(mat_chol[1])
+    log_det_chol = t(d * log(mat))
+  }
+  log_det_cholesky = log_det_chol
+}
+
+estimate_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu, List[Unknown] prec_chol, String model)
+return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
+{
+  n = nrow(X)
+  d = ncol(X)
+  n_components = nrow(mu)
+
+  log_det = compute_log_det_cholesky(prec_chol, model, d)
+  if(model == "VVV")
+  { 
+    log_prob = matrix(0, n, n_components) 
+    for(k in 1:n_components)
+    {
+      prec = as.matrix(prec_chol[k]) 
+      y = X %*% prec - mu[k,] %*% prec  # equivalent to (X - mu[k,]) %*% prec
+      log_prob[, k] = rowSums(y*y)
+    }      
+  }
+  else if(model == "EEE")
+  { 
+    log_prob = matrix(0, n, n_components) 
+    prec = as.matrix(prec_chol[1])
+    for(k in 1:n_components) {
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y) # TODO replace y*y with squared built-in
+    }
+  }
+  else if(model ==  "VVI")
+  {
+    prec = as.matrix(prec_chol[1])
+    precisions = prec^2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix*t(rowSums(mu^2 * precisions)) -
+                    2. * (X %*% t(mu * precisions)) +
+                    X^2 %*% t(precisions))
+  }
+  else if (model == "VII")
+  {
+    prec = as.matrix(prec_chol[1])
+    precisions = prec^ 2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions) -
+                    2 * X %*% t(mu * precisions) +
+                    rowSums(X*X) %*% t(precisions) ) # TODO replace rowSums(X*X) with squared rowNorm() built-in
+  }
+  if(ncol(log_det) == 1)
+    log_det = matrix(1, 1, ncol(log_prob)) * log_det 
+  es_log_prob = -.5 * (d * log(2 * pi) + log_prob) + log_det
+}
+
+logsumexp = function(Matrix[Double] M) # TODO replace with a built-in function logsumexp
+return(Matrix[Double] soft)
+{
+  max = max(M)
+  ds = M - max
+  sumOfexp = rowSums(exp(ds))
+  soft = max + log(sumOfexp)
+}
+
+# compute the number of estimated parameters
+estimate_free_param = function(Integer n_components, Integer n_features, String model)
+return (Integer n_parameters)
+{
+  if(model == "VVV")
+    cov_param = n_components * n_features * (n_features + 1) / 2
+  else if(model == "EEE")
+    cov_param = n_features * (n_features + 1) / 2
+  else if (model == "VVI")
+    cov_param = n_components * n_features
+  else if (model == "VII")
+    cov_param = n_components
+  else 
+    stop("invalid model expecting any of [VVV,EEE,VVI,VII], found "+model)
+  mean_param = n_features * n_components
+  
+  n_parameters = as.integer( cov_param + mean_param + n_components - 1 )
+
+}
+
+fit = function(Matrix[Double] X, Integer n_components, String model, String init_params, Integer iter , Double reg_covar, Double tol)
+return (Matrix[Double] label, Matrix[Double] predict_prob, Double log_prob_norm)
+{
+
+  lower_bound = 0
+  converged = FALSE
+  n = nrow(X)
+  [weight, mean, sigma, precision_chol] = initialize_param(X, n_components,init_params, model, reg_covar, tol)
+  i = 1
+  while(i <= iter & !converged) {
+    prev_lower_bound = lower_bound
+    [log_prob_norm, log_resp] = e_step(X,weight, mean, precision_chol, model)
+    [weight, mean, sigma, precision_chol] = m_step(X, log_resp, n_components, model, reg_covar)
+    lower_bound = log_prob_norm
+    change = lower_bound - prev_lower_bound
+    if(abs(change) < tol)
+      converged = TRUE
+    i = i+1   
+  }
+  [log_prob_norm, log_resp] = e_step(X,weight, mean, precision_chol, model)
+  label = rowIndexMax(log_resp)
+  predict_prob = exp(log_resp)
+}
+
+getBIC = function(Integer n, Double norm, Integer df)
+return(Double bic)
+{
+  bic = -2 * norm * n + df * log(n)
+}
+
+# check if covariance matrix is symmetric and positive definite
+checkSPD = function(Matrix[Double] A)
+return(Boolean isSPD)
+{
+  # abs(a - t(a)) <= (absoluteTolerance + relativeTolerance * abs(b))
+  sym = abs(A - t(A)) <= (1e-10 * abs(t(A)))
+  if(sum(sym == 0) == 0)
+  {
+    [eval, evec] = eigen(A);
+    if(sum(eval < 0) == 0) #check positive definite
+      isSPD = TRUE
+    else  isSPD = FALSE
+  }
+  else isSPD = FALSE
+}
+
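The E-step in `gmm.dml` (`e_step` together with `logsumexp`) normalizes the weighted log-probabilities row by row. The responsibilities are `exp(weighted_log_prob - logsumexp(weighted_log_prob))`, and `norm` is the mean of the row normalizers.

The script's `logsumexp` shifts by the maximum of the whole matrix. The Java sketch below swaps in a per-row maximum, a common variant that pins each row's largest shifted value at zero. With that shift, a row lying far below the global maximum (more than about 745 in log space) cannot underflow to `log(0)`. Class and method names are hypothetical.

```java
// Hypothetical sketch, not SystemDS code: the E-step normalization of
// e_step/logsumexp in gmm.dml, written with a per-row maximum.
final class EStepSketch {

    private EStepSketch() {}

    /** Row-wise log(sum(exp(row))), shifted by the row maximum to avoid overflow/underflow. */
    static double[] logSumExpRows(double[][] m) {
        double[] out = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (double v : m[i]) {
                max = Math.max(max, v);
            }
            if (max == Double.NEGATIVE_INFINITY) { // all entries are log(0)
                out[i] = max;
                continue;
            }
            double sum = 0.0;
            for (double v : m[i]) {
                sum += Math.exp(v - max); // largest term contributes exp(0) = 1
            }
            out[i] = max + Math.log(sum);
        }
        return out;
    }

    /** Responsibilities exp(logProb - logSumExp) per row; each row sums to 1. */
    static double[][] responsibilities(double[][] weightedLogProb) {
        double[] norm = logSumExpRows(weightedLogProb);
        double[][] resp = new double[weightedLogProb.length][];
        for (int i = 0; i < weightedLogProb.length; i++) {
            resp[i] = new double[weightedLogProb[i].length];
            for (int k = 0; k < resp[i].length; k++) {
                resp[i][k] = Math.exp(weightedLogProb[i][k] - norm[i]);
            }
        }
        return resp;
    }

    /** Mean of the row normalizers: the quantity gmm.dml calls norm and tracks for convergence. */
    static double meanLogLikelihood(double[][] weightedLogProb) {
        double s = 0.0;
        for (double v : logSumExpRows(weightedLogProb)) {
            s += v;
        }
        return s / weightedLogProb.length;
    }
}
```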
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index b6733d2..22134ea 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -95,6 +95,7 @@ public enum Builtins {
 	EVAL("eval", false),
 	FLOOR("floor", false),
 	GLM("glm", true),
+	GMM("gmm", true),
 	GNMF("gnmf", true),
 	GRID_SEARCH("gridSearch", true),
 	HYPERBAND("hyperband", true),
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index 985be59..b2a17f0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -39,7 +39,9 @@ import org.apache.sysds.runtime.util.DataConverter;
  * matrix inverse, matrix decompositions (QR, LU, Eigen), solve 
  */
 public class LibCommonsMath 
-{	
+{
+	static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-10;
+
 	private LibCommonsMath() {
 		//prevent instantiation via private constructor
 	}
@@ -260,7 +262,7 @@ public class LibCommonsMath
 	private static MatrixBlock computeCholesky(Array2DRowRealMatrix in) {
 		if ( !in.isSquare() )
 			throw new DMLRuntimeException("Input to cholesky() must be square matrix -- given: a " + in.getRowDimension() + "x" + in.getColumnDimension() + " matrix.");
-		CholeskyDecomposition cholesky = new CholeskyDecomposition(in, 1e-14,
+		CholeskyDecomposition cholesky = new CholeskyDecomposition(in, RELATIVE_SYMMETRY_THRESHOLD,
 			CholeskyDecomposition.DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
 		RealMatrix rmL = cholesky.getL();
 		return DataConverter.convertToMatrixBlock(rmL.getData());
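The LibCommonsMath change above loosens the relative symmetry threshold handed to Commons Math's `CholeskyDecomposition`. As we understand that class, it treats an off-diagonal pair as non-symmetric when `|a[i][j] - a[j][i]|` exceeds the threshold times `max(|a[i][j]|, |a[j][i]|)`.

The self-contained sketch below reimplements only that test; it does not call Commons Math, and the class and method names are hypothetical. It shows a round-off gap of about 1e-12 being rejected at 1e-14 but accepted at 1e-10. A gap of that size is plausibly what an accumulated covariance product can pick up.

```java
// Sketch of a relative symmetry test of the kind applied before a Cholesky
// factorization; SymmetryCheck and isSymmetric are hypothetical names.
final class SymmetryCheck {

    private SymmetryCheck() {}

    /** True if |a[i][j] - a[j][i]| <= threshold * max(|a[i][j]|, |a[j][i]|) for all i < j. */
    static boolean isSymmetric(double[][] a, double relativeThreshold) {
        for (int i = 0; i < a.length; i++) {
            for (int j = i + 1; j < a.length; j++) {
                double aij = a[i][j];
                double aji = a[j][i];
                double maxDelta = relativeThreshold * Math.max(Math.abs(aij), Math.abs(aji));
                if (Math.abs(aij - aji) > maxDelta) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Off-diagonal entries that differ only by round-off (relative gap about 1e-12).
        double[][] cov = {{2.0, 1.0}, {1.0 + 1e-12, 2.0}};
        System.out.println(isSymmetric(cov, 1e-14)); // false: the old threshold rejects it
        System.out.println(isSymmetric(cov, 1e-10)); // true: the new threshold accepts it
    }
}
```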
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java
new file mode 100644
index 0000000..a637dd9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class BuiltinGMMTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "GMM";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGMMTest.class.getSimpleName() + "/";
+
+	private final static double eps = 1;
+	private final static double tol = 1e-3;
+	private final static double tol1 = 1e-4;
+	private final static double tol2 = 1e-5;
+	private final static int rows = 100;
+	private final static double spDense = 0.99;
+	private final static String DATASET = SCRIPT_DIR + "functions/transform/input/iris/iris.csv";
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+	}
+
+	@Test
+	public void testGMMM1() { runGMMTest(3, "VVV", "random", 10, 0.0000001, tol,true, LopProperties.ExecType.CP); }
+
+	@Test
+	public void testGMMM2() {
+		runGMMTest(3, "EEE", "random", 150, 0.000001, tol1,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM3() {
+		runGMMTest(3, "VVI", "random", 10, 0.000000001, tol,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM4() {
+		runGMMTest(3, "VII", "random", 50, 0.000001, tol2,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM1Kmean() {
+		runGMMTest(3, "VVV", "kmeans", 10, 0.0000001, tol,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM2Kmean() {
+		runGMMTest(3, "EEE", "kmeans", 150, 0.000001, tol,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM3Kmean() {
+		runGMMTest(3, "VVI", "kmeans", 10, 0.00000001, tol1,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM4Kmean() {
+		runGMMTest(3, "VII", "kmeans", 50, 0.000001, tol2,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM1Spark() { runGMMTest(3, "VVV", "random", 10, 0.0000001, tol,true, LopProperties.ExecType.SPARK); }
+
+	@Test
+	public void testGMMM2Spark() {
+		runGMMTest(3, "EEE", "random", 50, 0.0000001, tol,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMMS3Spark() {
+		runGMMTest(3, "VVI", "random", 100, 0.000001, tol,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM4Spark() {
+		runGMMTest(3, "VII", "random", 100, 0.000001, tol1,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMM1KmeanSpark() {
+		runGMMTest(3, "VVV", "kmeans", 100, 0.000001, tol2,false, LopProperties.ExecType.SPARK);
+	}
+
+	@Test
+	public void testGMMM2KmeanSpark() {
+		runGMMTest(3, "EEE", "kmeans", 50, 0.00000001, tol1,false, LopProperties.ExecType.SPARK);
+	}
+
+	@Test
+	public void testGMMM3KmeanSpark() {
+		runGMMTest(3, "VVI", "kmeans", 100, 0.000001, tol,false, LopProperties.ExecType.SPARK);
+	}
+
+	@Test
+	public void testGMMM4KmeanSpark() {
+		runGMMTest(3, "VII", "kmeans", 100, 0.000001, tol,false, LopProperties.ExecType.SPARK);
+	}
+
+	private void runGMMTest(int G_mixtures, String model, String init_param, int iter, double reg, double tol, boolean rewrite,
+			LopProperties.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", DATASET, String.valueOf(G_mixtures), model, init_param, String.valueOf(iter),
+					String.valueOf(reg), String.valueOf(tol), output("B"), output("O")};
+
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + DATASET + " " + String
+					.valueOf(G_mixtures) + " " + model + " " + expectedDir();
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			//compare matrices
+			HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("O");
+			HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("O");
+			System.out.println(dmlfile.values().iterator().next().doubleValue());
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/GMM.R b/src/test/scripts/functions/builtin/GMM.R
new file mode 100644
index 0000000..99d0b9a
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GMM.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library(mclust, quietly = TRUE)
+library("matrixStats") 
+
+X = iris[,1:4]
+fit =  Mclust(X, modelType = args[3], G=args[2])
+prob = fit$z
+out = rowMaxs(fit$z) < 0.7
+out = as.double(out)
+
+writeMM(as(prob, "CsparseMatrix"), paste(args[4], "B", sep=""))
+writeMM(as(out, "CsparseMatrix"), paste(args[4], "O", sep=""))
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/GMM.dml b/src/test/scripts/functions/builtin/GMM.dml
new file mode 100644
index 0000000..76ecc5a
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GMM.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1, data_type = "frame", format = "csv")
+X = as.matrix(X[, 2:5])
+[prob, labels, df, bic] = gmm(X=X, n_components = $2,  model = $3,  init_params = $4, iter = $5, reg_covar = $6, tol = $7, verbose=TRUE)
+out = (rowMaxs(prob) < 0.7)
+cluster = colSums(prob == rowMaxs(prob))
+# print("clusters "+toString(cluster))
+# print("bic "+bic)
+# print("df "+df)
+write(prob, $8)
+write(out, $9)
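Both test scripts reduce the membership probabilities to the same 0/1 indicator before comparison. GMM.dml computes `out = (rowMaxs(prob) < 0.7)` and GMM.R computes `rowMaxs(fit$z) < 0.7`, so an observation is flagged when no component claims it with probability of at least 0.7. The Java sketch below expresses that reduction; the class and method names are hypothetical.

```java
// Hypothetical helper mirroring `out = (rowMaxs(prob) < 0.7)` from GMM.dml and GMM.R.
final class UncertaintyIndicator {

    private UncertaintyIndicator() {}

    /** Returns 1.0 for each row whose largest membership probability is below threshold, else 0.0. */
    static double[] flagUncertain(double[][] prob, double threshold) {
        double[] out = new double[prob.length];
        for (int i = 0; i < prob.length; i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (double p : prob[i]) {
                max = Math.max(max, p);
            }
            out[i] = max < threshold ? 1.0 : 0.0;
        }
        return out;
    }
}
```

The indicator does not depend on the order in which the two implementations number the components, since a row maximum is unchanged by permuting columns.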
diff --git a/src/test/scripts/functions/transform/input/iris/iris.csv b/src/test/scripts/functions/transform/input/iris/iris.csv
index 11b46be..692bb3c 100644
--- a/src/test/scripts/functions/transform/input/iris/iris.csv
+++ b/src/test/scripts/functions/transform/input/iris/iris.csv
@@ -148,4 +148,4 @@ ID,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
 147,6.3,2.5,5,1.9,virginica
 148,6.5,3,5.2,2,virginica
 149,6.2,3.4,5.4,2.3,virginica
-150,5.9,3,5.1,1.8,virginica
+150,5.9,3,5.1,1.8,virginica
\ No newline at end of file