Posted to commits@spark.apache.org by me...@apache.org on 2016/04/30 08:13:06 UTC

spark git commit: [SPARK-14831][SPARKR] Make the SparkR MLlib API more consistent with Spark

Repository: spark
Updated Branches:
  refs/heads/master 43b149fb8 -> bc36fe6e8


[SPARK-14831][SPARKR] Make the SparkR MLlib API more consistent with Spark

## What changes were proposed in this pull request?

This PR splits the MLlib algorithms into two flavors:
 - the R flavor, which mimics the existing R API for these algorithms as closely as possible (implemented as an S4 specialization for Spark DataFrames)
 - the Spark flavor, which follows the same API and naming conventions as the rest of the MLlib algorithms in the other languages

In practice, the former calls the latter.
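
For illustration, the two flavors look like this side by side (a minimal sketch, assuming a SparkR session where `sqlContext` is initialized; it mirrors the examples in the diff below):

    df <- createDataFrame(sqlContext, iris)

    # Spark flavor: data first, then formula, following MLlib conventions
    model1 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")

    # R flavor: mirrors stats::glm() and delegates to spark.glm()
    model2 <- glm(Sepal_Length ~ Sepal_Width, family = "gaussian", data = df)

    summary(model1)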

## How was this patch tested?

The tests for the various algorithms were adapted to run against both interfaces.
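
A sketch of the adapted pattern (illustrative only; the actual tests are in test_mllib.R in the diff below):

    test_that("glm and spark.glm agree", {
      training <- suppressWarnings(createDataFrame(sqlContext, iris))
      rFlavor <- glm(Sepal_Width ~ Sepal_Length, data = training)
      sparkFlavor <- spark.glm(training, Sepal_Width ~ Sepal_Length)
      vals1 <- collect(select(predict(rFlavor, training), "prediction"))
      vals2 <- collect(select(predict(sparkFlavor, training), "prediction"))
      expect_true(all(abs(vals1 - vals2) < 1e-6))
    })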

Author: Timothy Hunter <ti...@databricks.com>

Closes #12789 from thunterdb/14831.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc36fe6e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc36fe6e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc36fe6e

Branch: refs/heads/master
Commit: bc36fe6e896ab0e64f6334b1e3fd6386d0c38238
Parents: 43b149f
Author: Timothy Hunter <ti...@databricks.com>
Authored: Fri Apr 29 23:13:03 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Apr 29 23:13:03 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                        |   7 +-
 R/pkg/R/generics.R                     |  16 +--
 R/pkg/R/mllib.R                        | 155 +++++++++++++++++-----------
 R/pkg/inst/tests/testthat/test_mllib.R | 141 ++++++++++++++++++++++++-
 4 files changed, 247 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc36fe6e/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 647db22..d2aebb3 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,12 +12,13 @@ export("print.jobj")
 
 # MLlib integration
 exportMethods("glm",
+              "spark.glm",
               "predict",
               "summary",
-              "kmeans",
+              "spark.kmeans",
               "fitted",
-              "naiveBayes",
-              "survreg")
+              "spark.naiveBayes",
+              "spark.survreg")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/bc36fe6e/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3db8925..a37cdf2 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1181,6 +1181,10 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
 #' @export
 setGeneric("year", function(x) { standardGeneric("year") })
 
+#' @rdname spark.glm
+#' @export
+setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+
 #' @rdname glm
 #' @export
 setGeneric("glm")
@@ -1193,21 +1197,21 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @export
 setGeneric("rbind", signature = "...")
 
-#' @rdname kmeans
+#' @rdname spark.kmeans
 #' @export
-setGeneric("kmeans")
+setGeneric("spark.kmeans", function(data, k, ...) { standardGeneric("spark.kmeans") })
 
 #' @rdname fitted
 #' @export
 setGeneric("fitted")
 
-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
 #' @export
-setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
 
-#' @rdname survreg
+#' @rdname spark.survreg
 #' @export
-setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
 
 #' @rdname ml.save
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/bc36fe6e/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index c2326ea..4f62d7c 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -17,6 +17,14 @@
 
 # mllib.R: Provides methods for MLlib integration
 
+# Integration with R's standard functions.
+# Most of MLlib's algorithms are provided in two flavors:
+# - a specialization of the default R methods (glm). These methods try to respect
+#   the inputs and the outputs of R's methods to the largest extent, but some small differences
+#   may exist.
+# - a set of methods that reflect the arguments of the other languages supported by Spark. These
+#   methods carry the `spark.` prefix: spark.glm, spark.kmeans, etc.
+
 #' @title S4 class that represents a generalized linear model
 #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
 #' @export
@@ -39,6 +47,54 @@ setClass("KMeansModel", representation(jobj = "jobj"))
 
 #' Fits a generalized linear model
 #'
+#' Fits a generalized linear model against a Spark DataFrame.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param family A description of the error distribution and link function to be used in the model.
+#'               This can be a character string naming a family function, a family function or
+#'               the result of a call to a family function. Refer to R's family documentation at
+#'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' @param epsilon Positive convergence tolerance of iterations.
+#' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @return a fitted generalized linear model
+#' @rdname spark.glm
+#' @export
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family="gaussian")
+#' summary(model)
+#' }
+setMethod(
+    "spark.glm",
+    signature(data = "SparkDataFrame", formula = "formula"),
+    function(data, formula, family = gaussian, epsilon = 1e-06, maxit = 25) {
+        if (is.character(family)) {
+            family <- get(family, mode = "function", envir = parent.frame())
+        }
+        if (is.function(family)) {
+            family <- family()
+        }
+        if (is.null(family$family)) {
+            print(family)
+            stop("'family' not recognized")
+        }
+
+        formula <- paste(deparse(formula), collapse = "")
+
+        jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
+        "fit", formula, data@sdf, family$family, family$link,
+        epsilon, as.integer(maxit))
+        return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+})
+
+#' Fits a generalized linear model (R-compliant).
+#'
 #' Fits a generalized linear model, similarly to R's glm().
 #'
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
@@ -64,23 +120,7 @@ setClass("KMeansModel", representation(jobj = "jobj"))
 #' }
 setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
           function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
-            if (is.character(family)) {
-              family <- get(family, mode = "function", envir = parent.frame())
-            }
-            if (is.function(family)) {
-              family <- family()
-            }
-            if (is.null(family$family)) {
-              print(family)
-              stop("'family' not recognized")
-            }
-
-            formula <- paste(deparse(formula), collapse = "")
-
-            jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
-                                "fit", formula, data@sdf, family$family, family$link,
-                                epsilon, as.integer(maxit))
-            return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+            spark.glm(data, formula, family, epsilon, maxit)
           })
 
 #' Get the summary of a generalized linear model
@@ -188,7 +228,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #'}
@@ -208,7 +248,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
 #' summary(model)
 #'}
 setMethod("summary", signature(object = "NaiveBayesModel"),
@@ -230,23 +270,23 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #'
 #' Fit a k-means model, similarly to R's kmeans().
 #'
-#' @param x SparkDataFrame for training
-#' @param centers Number of centers
-#' @param iter.max Maximum iteration number
-#' @param algorithm Algorithm choosen to fit the model
+#' @param data SparkDataFrame for training
+#' @param k Number of centers
+#' @param maxIter Maximum iteration number
+#' @param initializationMode Algorithm chosen to initialize the cluster centers
 #' @return A fitted k-means model
-#' @rdname kmeans
+#' @rdname spark.kmeans
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm="random")
+#' model <- spark.kmeans(data, k = 2, initializationMode="random")
 #' }
-setMethod("kmeans", signature(x = "SparkDataFrame"),
-          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
-            columnNames <- as.array(colnames(x))
-            algorithm <- match.arg(algorithm)
-            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
-                                centers, iter.max, algorithm, columnNames)
+setMethod("spark.kmeans", signature(data = "SparkDataFrame"),
+          function(data, k, maxIter = 10, initializationMode = c("random", "k-means||")) {
+            columnNames <- as.array(colnames(data))
+            initializationMode <- match.arg(initializationMode)
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf,
+                                k, maxIter, initializationMode, columnNames)
             return(new("KMeansModel", jobj = jobj))
          })
 
@@ -261,7 +301,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}
@@ -288,7 +328,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' summary(model)
 #' }
 setMethod("summary", signature(object = "KMeansModel"),
@@ -322,7 +362,7 @@ setMethod("summary", signature(object = "KMeansModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #' }
@@ -333,30 +373,28 @@ setMethod("predict", signature(object = "KMeansModel"),
 
 #' Fit a Bernoulli naive Bayes model
 #'
-#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
-#' categorical features are supported. The input should be a SparkDataFrame of observations instead
-#' of a contingency table.
+#' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported).
 #'
+#' @param data SparkDataFrame for training
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'               operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data SparkDataFrame for training
 #' @param laplace Smoothing parameter
 #' @return a fitted naive Bayes model
-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
 #' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
 #' @export
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
 #'}
-setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
-          function(formula, data, laplace = 0, ...) {
-            formula <- paste(deparse(formula), collapse = "")
-            jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
-                                 formula, data@sdf, laplace)
-            return(new("NaiveBayesModel", jobj = jobj))
-          })
+setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
+    function(data, formula, laplace = 0, ...) {
+        formula <- paste(deparse(formula), collapse = "")
+        jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
+                            formula, data@sdf, laplace)
+        return(new("NaiveBayesModel", jobj = jobj))
+    })
 
 #' Save the Bernoulli naive Bayes model to the input path.
 #'
@@ -371,7 +409,7 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }
@@ -396,7 +434,7 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }
@@ -446,7 +484,7 @@ setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm="random")
+#' model <- spark.kmeans(data, k = 2, initializationMode = "random")
 #' path <- "path/to/model"
 #' ml.save(model, path)
 #' }
@@ -489,29 +527,30 @@ ml.load <- function(path) {
 
 #' Fit an accelerated failure time (AFT) survival regression model.
 #'
-#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#' Fit an accelerated failure time (AFT) survival regression model on a Spark DataFrame.
 #'
+#' @param data SparkDataFrame for training.
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', ':', '+', and '-'.
 #'                Note that operator '.' is not supported currently.
-#' @param data SparkDataFrame for training.
 #' @return a fitted AFT survival regression model
-#' @rdname survreg
+#' @rdname spark.survreg
 #' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
 #' @export
 #' @examples
 #' \dontrun{
 #' df <- createDataFrame(sqlContext, ovarian)
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
 #' }
-setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
-          function(formula, data, ...) {
+setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, ...) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                 "fit", formula, data@sdf)
             return(new("AFTSurvivalRegressionModel", jobj = jobj))
           })
 
 #' Get the summary of an AFT survival regression model
 #'
 #' Returns the summary of an AFT survival regression model produced by survreg(),
@@ -523,7 +562,7 @@ setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' summary(model)
 #' }
 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
@@ -548,7 +587,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
 #' @export
 #' @examples
 #' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
 #' }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc36fe6e/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 6a822be..18a4e78 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -25,6 +25,137 @@ sc <- sparkR.init()
 
 sqlContext <- sparkRSQL.init(sc)
 
+test_that("formula of spark.glm", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  # directly calling the spark API
+  # dot minus and intercept vs native glm
+  model <- spark.glm(training, Sepal_Width ~ . - Species + 0)
+  vals <- collect(select(predict(model, training), "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # feature interaction vs native glm
+  model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length)
+  vals <- collect(select(predict(model, training), "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # glm should work with long formula
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  training$LongLongLongLongLongName <- training$Sepal_Width
+  training$VeryLongLongLongLonLongName <- training$Sepal_Length
+  training$AnotherLongLongLongLongName <- training$Species
+  model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName +
+    AnotherLongLongLongLongName)
+  vals <- collect(select(predict(model, training), "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
+test_that("spark.glm and predict", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  # gaussian family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # poisson family
+  model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
+                     family = poisson(link = identity))
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+  vals <- collect(select(prediction, "prediction"))
+  rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
+                                        data = iris, family = poisson(link = identity)), iris))
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+  # Test stats::predict is working
+  x <- rnorm(15)
+  y <- x + rnorm(15)
+  expect_equal(length(predict(lm(y ~ x))), 15)
+})
+
+test_that("spark.glm summary", {
+  # gaussian family
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species))
+
+  rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
+
+  coefs <- unlist(stats$coefficients)
+  rCoefs <- unlist(rStats$coefficients)
+  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(all(
+    rownames(stats$coefficients) ==
+    c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
+  expect_equal(stats$dispersion, rStats$dispersion)
+  expect_equal(stats$null.deviance, rStats$null.deviance)
+  expect_equal(stats$deviance, rStats$deviance)
+  expect_equal(stats$df.null, rStats$df.null)
+  expect_equal(stats$df.residual, rStats$df.residual)
+  expect_equal(stats$aic, rStats$aic)
+
+  # binomial family
+  df <- suppressWarnings(createDataFrame(sqlContext, iris))
+  training <- df[df$Species %in% c("versicolor", "virginica"), ]
+  stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
+    family = binomial(link = "logit")))
+
+  rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
+  rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+                        family = binomial(link = "logit")))
+
+  coefs <- unlist(stats$coefficients)
+  rCoefs <- unlist(rStats$coefficients)
+  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(all(
+    rownames(stats$coefficients) ==
+    c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+  expect_equal(stats$dispersion, rStats$dispersion)
+  expect_equal(stats$null.deviance, rStats$null.deviance)
+  expect_equal(stats$deviance, rStats$deviance)
+  expect_equal(stats$df.null, rStats$df.null)
+  expect_equal(stats$df.residual, rStats$df.residual)
+  expect_equal(stats$aic, rStats$aic)
+
+  # Test summary works on base GLM models
+  baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
+  baseSummary <- summary(baseModel)
+  expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+})
+
+test_that("spark.glm save/load", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+  s <- summary(m)
+
+  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+
+  expect_equal(s$coefficients, s2$coefficients)
+  expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+  expect_equal(s$dispersion, s2$dispersion)
+  expect_equal(s$null.deviance, s2$null.deviance)
+  expect_equal(s$deviance, s2$deviance)
+  expect_equal(s$df.null, s2$df.null)
+  expect_equal(s$df.residual, s2$df.residual)
+  expect_equal(s$aic, s2$aic)
+  expect_equal(s$iter, s2$iter)
+  expect_true(!s$is.loaded)
+  expect_true(s2$is.loaded)
+
+  unlink(modelPath)
+})
+
 test_that("formula of glm", {
   training <- suppressWarnings(createDataFrame(sqlContext, iris))
   # dot minus and intercept vs native glm
@@ -153,14 +284,14 @@ test_that("glm save/load", {
   unlink(modelPath)
 })
 
-test_that("kmeans", {
+test_that("spark.kmeans", {
   newIris <- iris
   newIris$Species <- NULL
   training <- suppressWarnings(createDataFrame(sqlContext, newIris))
 
   take(training, 1)
 
-  model <- kmeans(x = training, centers = 2)
+  model <- spark.kmeans(data = training, k = 2)
   sample <- take(select(predict(model, training), "prediction"), 1)
   expect_equal(typeof(sample$prediction), "integer")
   expect_equal(sample$prediction, 1)
@@ -235,7 +366,7 @@ test_that("naiveBayes", {
   t <- as.data.frame(Titanic)
   t1 <- t[t$Freq > 0, -5]
   df <- suppressWarnings(createDataFrame(sqlContext, t1))
-  m <- naiveBayes(Survived ~ ., data = df)
+  m <- spark.naiveBayes(df, Survived ~ .)
   s <- summary(m)
   expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
   expect_equal(sum(s$apriori), 1)
@@ -264,7 +395,7 @@ test_that("naiveBayes", {
   }
 })
 
-test_that("survreg", {
+test_that("spark.survreg", {
   # R code to reproduce the result.
   #
   #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
@@ -290,7 +421,7 @@ test_that("survreg", {
   data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
           list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
   df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
-  model <- survreg(Surv(time, status) ~ x + sex, df)
+  model <- spark.survreg(df, Surv(time, status) ~ x + sex)
   stats <- summary(model)
   coefs <- as.vector(stats$coefficients[, 1])
   rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)

