You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/11/16 20:23:15 UTC

[systemds] branch master updated: [SYSTEMDS-2625]: Cleanup GMM built-in function, new logSumExp

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 756e447  [SYSTEMDS-2625]: Cleanup GMM built-in function, new logSumExp
756e447 is described below

commit 756e4471527281439fc62610cdb30f6c8cfe9301
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Mon Nov 16 21:21:59 2020 +0100

    [SYSTEMDS-2625]: Cleanup GMM built-in function, new logSumExp
    
    This cleanup reorders the functions, simplifies the branching, and
    eliminates lists; covariances are now stored in matrices. The private
    function logsumexp is replaced with the new builtin logSumExp().
    
    Closes #1023.
---
 scripts/builtin/gmm.dml                            | 471 ++++++++++-----------
 scripts/builtin/logSumExp.dml                      |  64 +++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../sysds/runtime/matrix/data/LibCommonsMath.java  |   2 +-
 .../test/functions/builtin/BuiltinGMMTest.java     |  70 +--
 .../functions/builtin/BuiltinLogSumExpTest.java    | 107 +++++
 src/test/scripts/functions/builtin/GMM.dml         |  18 +-
 .../functions/builtin/{GMM.dml => logsumexp.R}     |  28 +-
 .../functions/builtin/{GMM.dml => logsumexp.dml}   |  14 +-
 9 files changed, 467 insertions(+), 308 deletions(-)

diff --git a/scripts/builtin/gmm.dml b/scripts/builtin/gmm.dml
index 7c327ca..a28dfd9 100644
--- a/scripts/builtin/gmm.dml
+++ b/scripts/builtin/gmm.dml
@@ -25,292 +25,249 @@
 
 # INPUT PARAMETERS:
 # ---------------------------------------------------------------------------------------------
-# NAME            TYPE    DEFAULT     MEANING
+# NAME            TYPE     DEFAULT     MEANING
 # ---------------------------------------------------------------------------------------------
-# X               Double   ---       Matrix X  
-# n_components    Integer  3         Number of n_components in the Gaussian mixture model
-# model           String   "VVV"     "VVV": unequal variance (full),each component has its own general covariance matrix
-#                                    "EEE": equal variance (tied), all components share the same general covariance matrix
-#                                    "VVI": spherical, unequal volume (diag), each component has its own diagonal covariance matrix 
-#                                    "VII": spherical, equal volume (spherical), each component has its own single variance
-# init_param      String  "kmeans"   initialize weights with "kmeans" or "random"
-# iterations      Integer  100       Number of iterations
-# reg_covar       Double   1e-6      regularization parameter for covariance matrix
-# tol             Double   0.000001  tolerance value for convergence 
+# X               Double   ---         Matrix X  
+# n_components    Integer  3           Number of n_components in the Gaussian mixture model
+# model           String   "VVV"       "VVV": unequal variance (full), each component has its own general covariance matrix
+#                                      "EEE": equal variance (tied), all components share the same general covariance matrix
+#                                      "VVI": diagonal, unequal volume and shape (diag), each component has its own diagonal
+#                                             covariance matrix
+#                                      "VII": spherical, unequal volume (spherical), each component has its own single variance
+# init_params     String   "kmeans"    initialize the responsibilities with "kmeans" or "random" (random uses seed)
+# iter            Integer  100         Maximum number of EM iterations
+# reg_covar       Double   1e-6        regularization added to the diagonal of each covariance matrix
+# tol             Double   0.000001    convergence tolerance on the change of the mean log-likelihood between iterations
 # ---------------------------------------------------------------------------------------------
 
 
 #Output(s)
 # ---------------------------------------------------------------------------------------------
-# NAME            TYPE    DEFAULT     MEANING
+# NAME            TYPE     DEFAULT     MEANING
 # ---------------------------------------------------------------------------------------------
-# weight          Double   ---      A matrix whose [i,k]th entry is the probability that observation i in the test data belongs to the kth class
-# labels          Double   ---      Prediction matrix
-# df              Integer  ---      Number of estimated parameters
-# bic             Double   ---      Bayesian information criterion for best iteration
-
-
-
-
-m_gmm = function(Matrix[Double] X, Integer n_components = 1, String model = "VVV", String init_params = "kmeans", Integer iter = 100, Double reg_covar = 1e-6, Double tol = 0.000001, Boolean verbose = FALSE )
-return (Matrix[Double] weights, Matrix[Double] labels, Integer df, Double bic)
-{
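+# labels          Double   ---         Index of the most probable component for each observation (n x 1)
+# predict_prob    Double   ---         A matrix whose [i,k]th entry is the probability that observation i belongs to the kth component
+# df              Integer  ---         Number of estimated parameters
+# bic             Double   ---         Bayesian information criterion of the final fitted model
+# mu              Double   ---         Fitted component means (n_components x n_features)
+# prec_chol       Double   ---         Cholesky factors of the precision matrices (layout depends on model)
+# weight          Double   ---         Mixing weight of each component (1 x n_components)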
+m_gmm = function(Matrix[Double] X, Integer n_components = 3, String model = "VVV", String init_params = "kmeans", 
+  Integer iter = 100, Double reg_covar = 1e-6, Double tol = 0.000001, Integer seed = -1, Boolean verbose = FALSE )
+return (Matrix[Double] labels, Matrix[Double] predict_prob, Integer df, Double bic, 
+  Matrix[Double] mu, Matrix[Double] prec_chol, Matrix[Double] weight)
+{ 
   # sanity checks
-  if(model != "VVV" & model != "EEE" & model != "VVI" & model != "VII")
+  if(model != "VVV" & model != "EEE" &    model != "VVI" & model != "VII")
     stop("model not supported, should be in VVV, EEE, VVI, VII");
 
-  [labels, weights, norm] = fit(X, n_components, model, init_params, iter, reg_covar, tol)
+  [labels, predict_prob, norm, mu, prec_chol, weight] = fit(X, n_components, 
+    model, init_params, iter, reg_covar, tol, seed,  verbose)
   df = estimate_free_param(n_components, ncol(X), model)
-  bic = getBIC(nrow(X),norm,df)
+  bic = getBIC(nrow(X), norm, df)
+}
+
+fit = function(Matrix[Double] X, Integer n_components, String model, String init_params,
+  Integer iter, Double reg_covar, Double tol, Integer seed, Boolean verbose)
+return (Matrix[Double] label, Matrix[Double] predict_prob, Double log_prob_norm,
+    Matrix[Double] mean, Matrix[Double] precision_chol, Matrix[Double] weight)
+{
+  et = FALSE
+  lower_bound = 0
+  converged = FALSE
+  [weight, mean, sigma, precision_chol] = initialize_param(X, n_components,init_params, model, reg_covar, tol, seed)
+  i = 1
+  while(i <= iter & !converged & !et) {
+    prev_lower_bound = lower_bound
+    [log_prob_norm, log_resp, weighted_log_prob] = e_step(X, weight, mean, precision_chol, model)
+    [weight, mean, sigma, precision_chol, et] = m_step(X, log_resp, n_components, model, reg_covar)
+    lower_bound = log_prob_norm
+    change = lower_bound - prev_lower_bound
+    converged = (abs(change) < tol)
+    if(verbose) {
+      print("executing " +i+" iteration")
+      print("converged " +converged)
+      print("diff: "+(abs(change))+" tol: "+tol)
+    }
+    i = i+1
+  }
+  if(et) {
+    print("warning: did not converge because some components have ill-defined empirical covariance 
+    (i.e., singleton matrix or non-symmetric).
+    \nTry to decrease the number of components, or increase reg_covar")
+    label = rowIndexMax(weighted_log_prob)
+    predict_prob = exp(log_resp)
+  }
+  else {
+    [log_prob_norm, log_resp, weighted_log_prob] = e_step(X, weight, mean, precision_chol, model)
+    label = rowIndexMax(weighted_log_prob)
+    predict_prob = exp(log_resp)
+  }
+
 }
- 
-initialize_param = function(Matrix[Double] X, Integer n_components, String init_params, String model, Double reg_covar, Double tol)
-return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol) 
+
+initialize_param = function(Matrix[Double] X, Integer n_components, String init_params, 
+  String model, Double reg_covar, Double tol , Integer seed)
+return (Matrix[Double] weight, Matrix[Double] mean, Matrix[Double] sigma, Matrix[Double] precision_chol) 
 {
   # create responsibility matrix, resp[n_samples, n_components]
   resp = matrix(0, nrow(X), n_components)
-  if(init_params == "kmeans")
-  {
-    [C, Y] = kmeans(X=X, k=n_components, runs=10, max_iter=10, 
-                    eps=tol, is_verbose=FALSE, avg_sample_size_per_centroid=25)
-    resp = resp + t(seq(1,n_components))
-    resp = resp == Y
+  if(init_params == "kmeans") {
+    [C, Y] = kmeans(X=X, k=n_components, runs=10,
+      eps=tol, is_verbose=FALSE, avg_sample_size_per_centroid=100)
+    resp = ((resp + t(seq(1, n_components))) == Y)
   }
-  else if(init_params == "random")
-  {
-    resp = Rand(rows = nrow(X), cols=n_components)
+  else if(init_params == "random") {
+    resp = Rand(rows = nrow(X), cols=n_components, seed=seed)
     resp = resp/rowSums(resp)
   }
   else stop("invalid parameter value, expected kmeans or random found "+init_params) 
   
-  [weight, mean, sigma, precision_chol] = initialize(X, resp, n_components, model, reg_covar)
-}
-
-
-# Matrix/Vector Parameters
-# input: (X[n_samples, n_features], resp[n_samples, n_components])
-# output: (weight[n_samples, n_components], mean[n_components, n_features], sigma/prec_chol depends on model type)
-initialize = function(Matrix[Double] X, Matrix[Double] resp, Integer n_components, String model,  Double reg_covar)
-return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol)  
-{
-  n =  nrow(X)
   [weight, mean, sigma] = estimate_gaussian_param(X, resp, n_components, model, reg_covar)
-  weight = weight/n
-  precision_chol = compute_precision_cholesky(sigma, model)
+  weight = weight/nrow(X)
+  [precision_chol, et] = compute_precision_cholesky(sigma, model, n_components)
+  if(et)
+    stop("Fitting the mixture model failed because some components have ill-defined empirical covariance 
+    (i.e., singleton matrix or non-symmetric).
+    \nTry to decrease the number of components, or increase reg_covar")
 }
 
-estimate_gaussian_param = function(Matrix[Double] X, Matrix[Double] resp, Integer n_components, String model, Double reg_covar)
-return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma)
+estimate_gaussian_param = function(Matrix[Double] X, Matrix[Double] resp, 
+  Integer n_components, String model, Double reg_covar)
+return (Matrix[Double] weight, Matrix[Double] mean, Matrix[Double] sigma)
 {
-  l = list()
-  n =  nrow(X)
+  MACHINE_PRECISION = 2.22e-16
   # estimate Gaussian parameter
-  nk = colSums(resp) + 2.220446049250313e-15
-  mu = (t(resp) %*% X) / t(nk)
-  sigma = list()
-  if(model == "VVV")
-    sigma = covariances_VVV(X, resp, mu, nk, reg_covar)
-  else if(model == "EEE")
-    sigma = covariances_EEE(X, resp, mu, nk, reg_covar)
-  else if(model ==  "VVI")
-    sigma = covariances_VVI(X, resp, mu, nk, reg_covar)
-  else if (model == "VII")
-    sigma = covariances_VII(X, resp, mu, nk, reg_covar)
-    
-  weight = nk
-  mean = mu # n_components * n_features
-}
-
-# Matrix/Vector Parameters/List
-# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
-# output: (sigma a list of length = n_components where each item in list is a covariance matrix of (n_features * n_features) dimensions)
-covariances_VVV = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
-return(List[Unknown] sigma)
-{
-  sigma = list()
-  for(k in 1:nrow(mu)) {
-    diff = X - mu[k,]
-    cov = (t(diff * resp[, k]) %*% diff) / as.scalar(nk[1,k])
+  weight = colSums(resp) + MACHINE_PRECISION # adding machine precision 
+  mean = (t(resp) %*% X) / t(weight) # mean dims:  n_components * n_features
+  
+  if(model == "VVV") {
+  # output: sigma, the n_components covariance matrices (each n_features x n_features)
+  # stacked row-wise (rbind) into one (n_components * n_features) x n_features matrix
+    sigma = matrix(0, 0, ncol(X))
+    for(k in 1:nrow(mean)) {
+      diff = X - mean[k,]
+      cov = (t(diff * resp[, k]) %*% diff) / as.scalar(weight[1,k])
+      cov = cov + diag(matrix(reg_covar, ncol(cov), 1))
+      sigma = rbind(sigma, cov)
+    }
+  }
+  else if(model == "EEE") {
+  # output: sigma, a single covariance matrix (n_features x n_features)
+  # shared by all components
+    avgX2 = t(X) %*% X
+    avgMean = (t(mean) * weight) %*% mean
+    cov = avgX2 - avgMean
+    cov = cov / sum(weight)
     cov = cov + diag(matrix(reg_covar, ncol(cov), 1))
-    sigma = append(sigma, cov)
+    sigma = cov
+  }
+  else if(model ==  "VVI") {
+  # output: sigma, an n_components x n_features matrix; row k holds the diagonal variances of component k
+    avgX2 = (t(resp) %*% (X*X)) / t(weight)
+    avgMean = mean ^ 2
+    avgMean2 = mean * (t(resp) %*% X) / t(weight)
+    cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
+    sigma = cov
+  }
+  else if (model == "VII") {
+  # output: sigma, an n_components x 1 column vector holding a single variance per component
+    avgX2 = (t(resp) %*% (X*X)) / t(weight)
+    avgMean = mean ^ 2
+    avgMean2 = mean * (t(resp) %*% X) / t(weight)
+    cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
+    sigma = rowMeans(cov)
   }
 }
 
-# Matrix/Vector Parameters/List
-# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
-# output: (sigma a list of length = 1 where  item in list is a covariance matrix of (n_features * n_features) dimensions)
-covariances_EEE = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
-return(List[Unknown] sigma)
-{
-  sigma = list()
-  avgX2 = t(X) %*% X
-  avgMean = (t(mu) * nk) %*% mu
-  cov = avgX2 - avgMean
-  cov = cov / sum(nk)
-  cov = cov + diag(matrix(reg_covar, ncol(cov), 1))
-  sigma = append(sigma, cov)
-}
-
-# Matrix/Vector Parameters/List
-# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
-# output: (sigma a list of length = 1 where item in list is a covariance matrix of (n_components * n_features) dimensions)
-covariances_VVI = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
-return(List[Unknown] sigma)
-{
-  sigma = list()
-  avgX2 = (t(resp) %*% (X*X)) / t(nk)
-  avgMean = mu ^ 2
-  avgMean2 = mu * (t(resp) %*% X) / t(nk)
-  cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
-  sigma = append(sigma, cov)
-}
-
-# Matrix/Vector Parameters/List
-# input: (X[n_samples, n_features], resp[n_samples, n_components],  mu[n_components, n_features], nk[1, n_components])
-# output: (sigma a list of length = 1 where item in list is a variance value for each component (1* n_components) dimensions)
-covariances_VII = function(Matrix[Double] X, Matrix[Double] resp, Matrix[Double] mu, Matrix[Double] nk, Double reg_covar)
-return(List[Unknown] sigma)
-{
-  sigma = list()
-  avgX2 = (t(resp) %*% (X*X)) / t(nk)
-  avgMean = mu ^ 2
-  avgMean2 = mu * (t(resp) %*% X) / t(nk)
-  cov = avgX2 - 2 * avgMean + avgMean2 + reg_covar
-  sigma = list(rowMeans(cov))
-}
-
-compute_precision_cholesky = function(List[Unknown] sigma, String model)
-return (List[Unknown] precision_chol)
+compute_precision_cholesky = function(Matrix[Double] sigma, String model, Integer n_components)
+return (Matrix[Double] precision_chol, Boolean earlyTermination )
 {
-  precision_chol = list()
-
+  earlyTermination = FALSE
   if(model == "VVV") {
-    comp = length(sigma)
-    for(k in 1:length(sigma)) {
-      cov = as.matrix(sigma[k]) 
+    index = 1; k = 1
+    precision_chol = matrix(0, 0, ncol(sigma))
+    while(k <= n_components) {
+      cov = sigma[index:(ncol(sigma)*k), ]
       isSPD = checkSPD(cov)
       if(isSPD) {
-        cov_chol = cholesky(cov)
+        cov_chol = choleskymatrix(cov)
         pre_chol = t(inv(cov_chol))
-        precision_chol = append(precision_chol, pre_chol)
-      } else 
-        stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singleton matrix or non-symmetric )."+ 
-        "\nTry to decrease the number of components, or increase reg_covar")
+        precision_chol = rbind(precision_chol, pre_chol)
+        index = index + ncol(sigma)
+        k = k+1
+      } else {
+        earlyTermination = TRUE;
+        k = n_components + 1
+      }
     }
   }
   else if(model == "EEE") {
-    cov = as.matrix(sigma[1])
+    cov = sigma
     isSPD = checkSPD(cov)
     if(isSPD) {
       cov_chol = cholesky(cov)
       pre_chol = t(inv(cov_chol))
-      precision_chol = append(precision_chol, pre_chol)
-    } else 
-      stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singleton matrix or non-symmetric)."+ 
-      "\nTry to decrease the number of components, or increase reg_covar")
+      precision_chol = pre_chol
+    } else
+      earlyTermination = TRUE
   }
   else {
-    cov = as.matrix(sigma[1])
+    cov = sigma
     if(sum(cov <= 0) > 0)
-      stop("Fitting the mixture model failed because some components have ill-defined empirical covariance (i.e., singleton matrix or non-symmetric)."+ 
-      "\nTry to decrease the number of components, or increase reg_covar")
+      earlyTermination = TRUE
     else {
-      precision_chol = append(precision_chol, 1.0/sqrt(cov))
+      precision_chol = 1.0/sqrt(cov)
     }
   }
 }
 
 # Expectation step
-e_step = function(Matrix[Double] X, Matrix[Double] w, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
-return(Double norm, Matrix[Double] log_resp){
-  weighted_log_prob = estimate_weighted_log_prob(X, w, mu, precisions_cholesky, model)
-  log_prob_norm = logsumexp(weighted_log_prob)
+e_step = function(Matrix[Double] X, Matrix[Double] w, Matrix[Double] mu,
+  Matrix[Double] precisions_cholesky, String model)
+  return(Double norm, Matrix[Double] log_resp, Matrix[Double] weighted_log_prob)
+{
+  weighted_log_prob =  estimate_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(w)
+  log_prob_norm = logSumExp(weighted_log_prob, "rows")
   log_resp = weighted_log_prob - log_prob_norm
   norm = mean(log_prob_norm)
 }
 
 # maximization Step
 m_step = function(Matrix[Double] X, Matrix[Double] log_resp, Integer n_components, String model, Double reg_covar)
-return (Matrix[Double] weight, Matrix[Double] mean, List[Unknown] sigma, List[Unknown] precision_chol) {
-  n =  nrow(X)
-  [weight, mean, sigma] = estimate_gaussian_param(X, exp(log_resp), n_components, model, reg_covar)
-  weight = weight/n
-  precision_chol = compute_precision_cholesky(sigma, model)
-}
-
-estimate_weighted_log_prob = function(Matrix[Double] X, Matrix[Double] w, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
-return (Matrix[Double] weight_log_pro)
-{
-  weight_log_pro = estimate_log_prob(X, mu, precisions_cholesky, model) + estimate_log_weights(w)
-}
-
-estimate_log_weights = function(Matrix[Double] w)
-return (Matrix[Double] log_weight)
-{
-  log_weight = log(w)
-}
-
-estimate_log_prob = function(Matrix[Double] X, Matrix[Double] mu, List[Unknown] precisions_cholesky, String model)
-return (Matrix[Double] log_prob)
+  return (Matrix[Double] weight, Matrix[Double] mean, Matrix[Double] sigma, Matrix[Double] precision_chol, Boolean et) 
 {
-  log_prob = estimate_log_gaussian_prob(
-            X, mu, precisions_cholesky, model)
-}
-
-compute_log_det_cholesky = function(List[Unknown] mat_chol, String model, Integer d)
-return(Matrix[Double] log_det_cholesky)
-{
-  comp = length(mat_chol)
-
-  if(model == "VVV") {
-    log_det_chol = matrix(0, 1, comp)
-    for(k in 1:comp) {
-      mat = as.matrix(mat_chol[k])
-      log_det = sum(log(diag(t(mat))))   # have to take the log of diag elements only
-      log_det_chol[1,k] = log_det
-    }
-  }
-  else if(model == "EEE") {
-    mat = as.matrix(mat_chol[1])
-    log_det_chol = as.matrix(sum(log(diag(mat))))
-  }
-  else if(model ==  "VVI") {
-    mat = as.matrix(mat_chol[1])
-    log_det_chol = t(rowSums(log(mat)))
-  }
-  else if (model == "VII") {
-    mat = as.matrix(mat_chol[1])
-    log_det_chol = t(d * log(mat))
-  }
-  log_det_cholesky = log_det_chol
+  [weight, mean, sigma] = estimate_gaussian_param(X, exp(log_resp), n_components, model, reg_covar)
+  weight = weight/nrow(X)
+  [precision_chol, et] = compute_precision_cholesky(sigma, model, n_components)
 }
 
-estimate_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu, List[Unknown] prec_chol, String model)
-return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
+estimate_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu, Matrix[Double] prec_chol, String model)
+  return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
 {
-  n = nrow(X)
-  d = ncol(X)
   n_components = nrow(mu)
 
-  log_det = compute_log_det_cholesky(prec_chol, model, d)
+  log_det = compute_log_det_cholesky(prec_chol, model, ncol(X))
   if(model == "VVV") {
-    log_prob = matrix(0, n, n_components)
+    log_prob = matrix(0, nrow(X), n_components) 
+    i = 1
     for(k in 1:n_components) {
-      prec = as.matrix(prec_chol[k]) 
+      prec = prec_chol[i:(k*ncol(X)),]
       y = X %*% prec - mu[k,] %*% prec  # changing here t intro:  y = X %*% prec - mu[k,] %*% prec 
       log_prob[, k] = rowSums(y*y)
+      i = i + ncol(X)
     }
   }
   else if(model == "EEE") {
-    log_prob = matrix(0, n, n_components)
-    prec = as.matrix(prec_chol[1])
+    log_prob = matrix(0, nrow(X), n_components)
+    prec = prec_chol
     for(k in 1:n_components) {
       y = X %*% prec - mu[k,] %*% prec
       log_prob[, k] = rowSums(y*y) # TODO replace y*y with squared built-in
     }
   }
   else if(model ==  "VVI") {
-    prec = as.matrix(prec_chol[1])
+    prec = prec_chol
     precisions = prec^2
     bc_matrix = matrix(1,nrow(X), nrow(mu))
     log_prob = (bc_matrix*t(rowSums(mu^2 * precisions)) -
@@ -318,7 +275,7 @@ return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
                     X^2 %*% t(precisions))
   }
   else if (model == "VII") {
-    prec = as.matrix(prec_chol[1])
+    prec = prec_chol
     precisions = prec^ 2
     bc_matrix = matrix(1,nrow(X), nrow(mu))
     log_prob = (bc_matrix * t(rowSums(mu^2) * precisions) -
@@ -327,21 +284,37 @@ return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
   }
   if(ncol(log_det) == 1)
     log_det = matrix(1, 1, ncol(log_prob)) * log_det 
-  es_log_prob = -.5 * (d * log(2 * pi) + log_prob) + log_det
+  es_log_prob = -.5 * (ncol(X) * log(2 * pi) + log_prob) + log_det
 }
 
-logsumexp = function(Matrix[Double] M) # TODO replace with a built-in function logsumexp
-return(Matrix[Double] soft)
+compute_log_det_cholesky = function(Matrix[Double] mat_chol, String model, Integer d)
+  return(Matrix[Double] log_det_cholesky)
 {
-  max = max(M)
-  ds = M - max
-  sumOfexp = rowSums(exp(ds))
-  soft = max + log(sumOfexp)
+  comp = nrow(mat_chol)/ncol(mat_chol)
+
+  if(model == "VVV") {
+    log_det_chol = matrix(0, 1, comp)
+    i = 1
+    for(k in 1:comp) {
+      mat = mat_chol[i:(k*ncol(mat_chol))]
+      log_det = sum(log(diag(t(mat))))   # have to take the log of diag elements only
+      log_det_chol[1,k] = log_det
+      i = i + ncol(mat_chol) 
+    }
+  }
+  else if(model == "EEE")
+    log_det_chol = as.matrix(sum(log(diag(mat_chol))))
+  else if(model ==  "VVI")
+    log_det_chol = t(rowSums(log(mat_chol)))
+  else if (model == "VII")
+    log_det_chol = t(d * log(mat_chol))
+    
+  log_det_cholesky = log_det_chol
 }
 
 # compute the number of estimated parameters
 estimate_free_param = function(Integer n_components, Integer n_features, String model)
-return (Integer n_parameters)
+  return (Integer n_parameters)
 {
   if(model == "VVV")
     cov_param = n_components * n_features * (n_features + 1) / 2
@@ -351,54 +324,48 @@ return (Integer n_parameters)
     cov_param = n_components * n_features
   else if (model == "VII")
     cov_param = n_components
-  else 
+  else
     stop("invalid model expecting any of [VVV,EEE,VVI,VII], found "+model)
   mean_param = n_features * n_components
   
   n_parameters = as.integer( cov_param + mean_param + n_components - 1 )
 }
 
-fit = function(Matrix[Double] X, Integer n_components, String model, String init_params, Integer iter , Double reg_covar, Double tol)
-return (Matrix[Double] label, Matrix[Double] predict_prob, Double log_prob_norm)
-{
-  lower_bound = 0
-  converged = FALSE
-  n = nrow(X)
-  [weight, mean, sigma, precision_chol] = initialize_param(X, n_components,init_params, model, reg_covar, tol)
-  i = 1
-  while(i <= iter & !converged) {
-    prev_lower_bound = lower_bound
-    [log_prob_norm, log_resp] = e_step(X,weight, mean, precision_chol, model)
-    [weight, mean, sigma, precision_chol] = m_step(X, log_resp, n_components, model, reg_covar)
-    lower_bound = log_prob_norm
-    change = lower_bound - prev_lower_bound
-    if(abs(change) < tol)
-      converged = TRUE
-    i = i+1
-  }
-  [log_prob_norm, log_resp] = e_step(X,weight, mean, precision_chol, model)
-  label = rowIndexMax(log_resp)
-  predict_prob = exp(log_resp)
-}
-
 getBIC = function(Integer n, Double norm, Integer df)
-return(Double bic)
+  return(Double bic)
 {
   bic = -2 * norm * n + df * log(n)
 }
 
 # check if covariance matrix is symmetric and positive definite
 checkSPD = function(Matrix[Double] A)
-return(Boolean isSPD)
+  return(Boolean isSPD)
 {
   # abs(a - t(a)) <= (absoluteTolerance + relativeTolerance * abs(b))
-  sym = abs(A - t(A)) <= (1e-10 * abs(t(A)))
-  if(sum(sym == 0) == 0)
-  {
+  sym = abs(A - t(A)) <= (1e-4 + abs(t(A)))
+  if(sum(sym == 0) == 0) {
     [eval, evec] = eigen(A);
-    if(sum(eval < 0) == 0) #check positive definite
-      isSPD = TRUE
-    else  isSPD = FALSE
+    #check positive definite
+    isSPD = (sum(eval < 0) < 1e-4)
   }
   else isSPD = FALSE
+  # isSPD = TRUE
+}
+
+choleskymatrix = function(Matrix[Double] m)
+  return(Matrix[Double] L)
+{
+  rows = nrow(m)
+  cols = ncol(m)
+  L = diag(matrix(0, rows, 1))
+  for(i in 1:rows) {
+    for(k in 1:i) {
+      sum = sum(L[1:k, i] * L[1:k, k])
+      if(i == k)
+        L[k, i] = sqrt(m[i, i] - sum)
+      else
+        L[k, i] = (m[k, i] - sum) / L[k, k]
+    }
+  }
+  L = t(L)
 }
diff --git a/scripts/builtin/logSumExp.dml b/scripts/builtin/logSumExp.dml
new file mode 100644
index 0000000..3114f84
--- /dev/null
+++ b/scripts/builtin/logSumExp.dml
@@ -0,0 +1,64 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# ------------------------------------------
+# Built-in LOGSUMEXP 
+# ------------------------------------------
+
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME        TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# M           Double  ---         input matrix
+# margin      String  none        if the logsumexp of rows is required set margin = "rows"
+#                                 if the logsumexp of columns is required set margin = "cols"
+#                                 if set to "none", the logsumexp over all cells is returned as a 1x1 matrix
+# ---------------------------------------------------------------------------------------------
+
+
+#Output(s)
+# ---------------------------------------------------------------------------------------------
+# NAME        TYPE    DEFAULT     MEANING
+# ---------------------------------------------------------------------------------------------
+# output      Double  ---         1x1 matrix ("none"), column vector with one value per row ("rows"), or row vector with one value per column ("cols")
+
+m_logSumExp = function(Matrix[Double] M, String margin = "none")
+return(Matrix[Double] output)
+{
+  if(margin == "rows") {
+    ds = M - rowMaxs(M)
+    rSumOfexp = rowSums(exp(ds))
+    output = rowMaxs(M) + log(rSumOfexp)
+  }
+  else if(margin == "cols") {
+    ds = M - colMaxs(M)
+    cSumOfexp = colSums(exp(ds))
+    output = colMaxs(M) + log(cSumOfexp)
+  }
+  else if(margin == "none") {
+    ds = M - max(M)
+    sumOfexp = sum(exp(ds))
+    output = as.matrix(max(M) + log(sumOfexp))
+  }
+  else 
+		stop("invalid margin value expecting rows, cols or none found: "+margin)
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 273a520..b4f9980 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -130,6 +130,7 @@ public enum Builtins {
 	LMDS("lmDS", true),
 	LMPREDICT("lmpredict", true),
 	LOG("log", false),
+	LOGSUMEXP("logSumExp", true),
 	LSTM("lstm", false, ReturnType.MULTI_RETURN),
 	LSTM_BACKWARD("lstm_backward", false, ReturnType.MULTI_RETURN),
 	LU("lu", false, ReturnType.MULTI_RETURN),
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index b2a17f0..c822259 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -40,7 +40,7 @@ import org.apache.sysds.runtime.util.DataConverter;
  */
 public class LibCommonsMath 
 {
-	static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-10;
+	static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-6;
 
 	private LibCommonsMath() {
 		//prevent instantiation via private constructor
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java
index 59cfd3a..4c462ec 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMTest.java
@@ -39,8 +39,7 @@ public class BuiltinGMMTest extends AutomatedTestBase {
 	private final static double tol = 1e-3;
 	private final static double tol1 = 1e-4;
 	private final static double tol2 = 1e-5;
-	//private final static int rows = 100;
-	//private final static double spDense = 0.99;
+
 	private final static String DATASET = SCRIPT_DIR + "functions/transform/input/iris/iris.csv";
 
 	@Override
@@ -49,83 +48,104 @@ public class BuiltinGMMTest extends AutomatedTestBase {
 	}
 
 	@Test
-	public void testGMMM1() { runGMMTest(3, "VVV", "random", 10, 0.0000001, tol,true, LopProperties.ExecType.CP); }
+	public void testGMMMDefault() { runGMMTest(10, "VVV", "kmeans", 100,
+		1e-6, 0.000001,42, false, LopProperties.ExecType.CP); }
+
+	@Test
+	public void testGMMM1() { runGMMTest(3, "VVV", "random", 100,
+			0.0000001, 0.0001,42, true, LopProperties.ExecType.CP); }
 
 	@Test
 	public void testGMMM2() {
-		runGMMTest(3, "EEE", "random", 150, 0.000001, tol1,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "EEE", "random", 150,
+			0.000001, tol1,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM3() {
-		runGMMTest(3, "VVI", "random", 10, 0.000000001, tol,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VVI", "random", 10,
+			0.000000001, tol,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM4() {
-		runGMMTest(3, "VII", "random", 50, 0.000001, tol2,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VII", "random", 50,
+			0.000001, tol2,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM1Kmean() {
-		runGMMTest(3, "VVV", "kmeans", 10, 0.0000001, tol,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VVV", "kmeans", 10,
+			0.0000001, tol,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM2Kmean() {
-		runGMMTest(3, "EEE", "kmeans", 150, 0.000001, tol,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "EEE", "kmeans", 150,
+			0.000001, tol,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM3Kmean() {
-		runGMMTest(3, "VVI", "kmeans", 10, 0.00000001, tol1,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VVI", "kmeans", 10,
+			0.00000001, tol1,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
 	public void testGMMM4Kmean() {
-		runGMMTest(3, "VII", "kmeans", 50, 0.000001, tol2,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VII", "kmeans", 50,
+				0.000001, tol2,42,true, LopProperties.ExecType.CP);
 	}
 
 	@Test
-	public void testGMMM1Spark() { runGMMTest(3, "VVV", "random", 10, 0.0000001, tol,true, LopProperties.ExecType.SPARK); }
+	public void testGMMM1Spark() { runGMMTest(3, "VVV", "random", 10,
+			0.0000001, tol,42,true, LopProperties.ExecType.SPARK); }
 
 	@Test
 	public void testGMMM2Spark() {
-		runGMMTest(3, "EEE", "random", 50, 0.0000001, tol,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "EEE", "random", 50,
+			0.0000001, tol,42,true, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMMS3Spark() {
-		runGMMTest(3, "VVI", "random", 100, 0.000001, tol,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VVI", "random", 100,
+			0.000001, tol,42,true, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMM4Spark() {
-		runGMMTest(3, "VII", "random", 100, 0.000001, tol1,true, LopProperties.ExecType.CP);
+		runGMMTest(3, "VII", "random", 100,
+			0.000001, tol1,42,true, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMM1KmeanSpark() {
-		runGMMTest(3, "VVV", "kmeans", 100, 0.000001, tol2,false, LopProperties.ExecType.SPARK);
+		runGMMTest(3, "VVV", "kmeans", 100,
+			0.000001, tol2,42,false, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMM2KmeanSpark() {
-		runGMMTest(3, "EEE", "kmeans", 50, 0.00000001, tol1,false, LopProperties.ExecType.SPARK);
+		runGMMTest(3, "EEE", "kmeans", 50,
+			0.00000001, tol1,42,false, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMM3KmeanSpark() {
-		runGMMTest(3, "VVI", "kmeans", 100, 0.000001, tol,false, LopProperties.ExecType.SPARK);
+		runGMMTest(3, "VVI", "kmeans", 100,
+			0.000001, tol,42,false, LopProperties.ExecType.SPARK);
 	}
 
 	@Test
 	public void testGMMM4KmeanSpark() {
-		runGMMTest(3, "VII", "kmeans", 100, 0.000001, tol,false, LopProperties.ExecType.SPARK);
+		runGMMTest(3, "VII", "kmeans", 100,
+			0.000001, tol,42,false, LopProperties.ExecType.SPARK);
 	}
 
-	private void runGMMTest(int G_mixtures, String model, String init_param, int iter, double reg, double tol, boolean rewrite,
-			LopProperties.ExecType instType) {
+	private void runGMMTest(int G_mixtures, String model, String init_param, int iter,
+							double reg, double tol, int seed, boolean rewrite, LopProperties.ExecType instType) {
+
 		Types.ExecMode platformOld = setExecMode(instType);
 		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
 
@@ -133,12 +153,14 @@ public class BuiltinGMMTest extends AutomatedTestBase {
 			loadTestConfiguration(getTestConfiguration(TEST_NAME));
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + TEST_NAME + ".dml";
-			programArgs = new String[] {"-args", DATASET, String.valueOf(G_mixtures), model, init_param, String.valueOf(iter),
-					String.valueOf(reg), String.valueOf(tol), output("B"), output("O")};
+			// no explicit "-exec" here: the backend is selected via setExecMode(instType)
+			programArgs = new String[] {"-args", DATASET,
+				String.valueOf(G_mixtures), model, init_param, String.valueOf(iter), String.valueOf(reg),
+				String.valueOf(tol), String.valueOf(seed), output("B"), output("O")};
 
 			fullRScriptName = HOME + TEST_NAME + ".R";
-			rCmd = "Rscript" + " " + fullRScriptName + " " + DATASET + " " + String
-					.valueOf(G_mixtures) + " " + model + " " + expectedDir();
+			rCmd = "Rscript" + " " + fullRScriptName + " " + DATASET + " " +
+				String.valueOf(G_mixtures) + " " + model + " " + expectedDir();
 
 			runTest(true, false, null, -1);
 			runRScript(true);
@@ -146,7 +168,6 @@ public class BuiltinGMMTest extends AutomatedTestBase {
 			//compare matrices
 			HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("O");
 			HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("O");
-			System.out.println(dmlfile.values().iterator().next().doubleValue());
 			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
 		}
 		finally {
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLogSumExpTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLogSumExpTest.java
new file mode 100644
index 0000000..1a66e6a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLogSumExpTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import java.util.HashMap;
+
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Compares the DML logSumExp builtin against the R reference script
+ * (matrixStats rowLogSumExps / colLogSumExps / logSumExp) for the
+ * "rows", "cols" and "none" margins on the CP and Spark backends.
+ */
+public class BuiltinLogSumExpTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME = "logsumexp";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinLogSumExpTest.class.getSimpleName() + "/";
+
+	private final static double eps = 1e-4;
+	private final static int rows = 100;
+	private final static double spDense = 0.7;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+	}
+
+	@Test
+	public void testrowlogSumExpCP() {
+		runLogSumExpTest("rows", ExecType.CP);
+	}
+
+	@Test
+	public void testrowlogSumExpSP() {
+		runLogSumExpTest("rows", ExecType.SPARK);
+	}
+
+	@Test
+	public void testcollogSumExpCP() {
+		runLogSumExpTest("cols", ExecType.CP);
+	}
+
+	@Test
+	public void testcollogSumExpSP() {
+		runLogSumExpTest("cols", ExecType.SPARK);
+	}
+
+	@Test
+	public void testlogSumExpCP() {
+		runLogSumExpTest("none", ExecType.CP);
+	}
+
+	@Test
+	public void testlogSumExpSP() {
+		runLogSumExpTest("none", ExecType.SPARK);
+	}
+
+	/**
+	 * Runs logsumexp.dml and logsumexp.R on the same generated matrix A and
+	 * compares both results B cell-wise with tolerance {@code eps}.
+	 *
+	 * @param margin   "rows", "cols" or "none"; passed to both scripts
+	 * @param instType execution type used for the DML run
+	 */
+	private void runLogSumExpTest(String margin, ExecType instType) {
+		ExecMode oldPlatform = setExecMode(instType);
+		setOutputBuffering(false);
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			// generate the input matrix A shared by the DML and R runs
+			double[][] data = getRandomMatrix(rows, 10, 10, 100, spDense, 7);
+			writeInputMatrixWithMTD("A", data, true);
+
+			String home = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = home + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("A"), margin, output("B")};
+			fullRScriptName = home + TEST_NAME + ".R";
+			rCmd = String.join(" ", "Rscript", fullRScriptName, inputDir(), margin, expectedDir());
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// compare the DML result against the R reference
+			HashMap<CellIndex, Double> dmlResult = readDMLMatrixFromOutputDir("B");
+			HashMap<CellIndex, Double> rResult = readRMatrixFromExpectedDir("B");
+			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");
+		}
+		finally {
+			rtplatform = oldPlatform;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/GMM.dml b/src/test/scripts/functions/builtin/GMM.dml
index 76ecc5a..5bb557f 100644
--- a/src/test/scripts/functions/builtin/GMM.dml
+++ b/src/test/scripts/functions/builtin/GMM.dml
@@ -19,14 +19,14 @@
 #
 #-------------------------------------------------------------
 
-
 X = read($1, data_type = "frame", format = "csv")
-X = as.matrix(X[, 2:5])
-[prob, labels, df, bic] = gmm(X=X, n_components = $2,  model = $3,  init_params = $4, iter = $5, reg_covar = $6, tol = $7, verbose=TRUE)
+X = as.matrix(X[, 2:(ncol(X)-1)])  # drop first and last column (presumably row id and class label)
+
+[labels, prob, df, bic, mu, prec_chol, w] = gmm(X = X, n_components = $2,
+  model = $3, init_params = $4, iter = $5,
+  reg_covar = $6, tol = $7, seed = $8, verbose = TRUE)
+
 out = (rowMaxs(prob) < 0.7)
-cluster = colSums(prob == rowMaxs(prob))
-# print("clusters "+toString(cluster))
-# print("bic "+bic)
-# print("df "+df)
-write(prob, $8)
-write(out, $9)
+
+write(prob, $9)
+write(out, $10)
diff --git a/src/test/scripts/functions/builtin/GMM.dml b/src/test/scripts/functions/builtin/logsumexp.R
similarity index 70%
copy from src/test/scripts/functions/builtin/GMM.dml
copy to src/test/scripts/functions/builtin/logsumexp.R
index 76ecc5a..5d0cc6e 100644
--- a/src/test/scripts/functions/builtin/GMM.dml
+++ b/src/test/scripts/functions/builtin/logsumexp.R
@@ -20,13 +20,24 @@
 #-------------------------------------------------------------
 
 
-X = read($1, data_type = "frame", format = "csv")
-X = as.matrix(X[, 2:5])
-[prob, labels, df, bic] = gmm(X=X, n_components = $2,  model = $3,  init_params = $4, iter = $5, reg_covar = $6, tol = $7, verbose=TRUE)
-out = (rowMaxs(prob) < 0.7)
-cluster = colSums(prob == rowMaxs(prob))
-# print("clusters "+toString(cluster))
-# print("bic "+bic)
-# print("df "+df)
-write(prob, $8)
-write(out, $9)
+args <- commandArgs(TRUE)
+options(digits = 22)
+library("Matrix")
+library("matrixStats")
+
+# R reference for the DML logSumExp builtin.
+# args[1]: input dir containing A.mtx, args[2]: margin ("rows", "cols" or
+# "none"), args[3]: output dir; the result is written as <args[3]>B.
+M <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+opt <- args[2]
+if (opt == "rows") {
+  O <- rowLogSumExps(M)
+} else if (opt == "cols") {
+  O <- t(colLogSumExps(M))
+} else if (opt == "none") {
+  O <- logSumExp(M)
+} else {
+  stop("invalid margin value expecting rows, cols or none found: ", opt, call. = FALSE)
+}
+
+writeMM(as(O, "CsparseMatrix"), paste0(args[3], "B"))
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/GMM.dml b/src/test/scripts/functions/builtin/logsumexp.dml
similarity index 70%
copy from src/test/scripts/functions/builtin/GMM.dml
copy to src/test/scripts/functions/builtin/logsumexp.dml
index 76ecc5a..4f333dd 100644
--- a/src/test/scripts/functions/builtin/GMM.dml
+++ b/src/test/scripts/functions/builtin/logsumexp.dml
@@ -19,14 +19,6 @@
 #
 #-------------------------------------------------------------
 
-
-X = read($1, data_type = "frame", format = "csv")
-X = as.matrix(X[, 2:5])
-[prob, labels, df, bic] = gmm(X=X, n_components = $2,  model = $3,  init_params = $4, iter = $5, reg_covar = $6, tol = $7, verbose=TRUE)
-out = (rowMaxs(prob) < 0.7)
-cluster = colSums(prob == rowMaxs(prob))
-# print("clusters "+toString(cluster))
-# print("bic "+bic)
-# print("df "+df)
-write(prob, $8)
-write(out, $9)
+M = read($1)  # $1: input matrix A written by the test
+O = logSumExp(M, $2)  # $2: margin, one of "rows", "cols" or "none"
+write(O, $3)  # $3: output B, compared against the R reference result