You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by fe...@apache.org on 2017/01/12 04:01:17 UTC

spark git commit: [SPARK-19133][SPARKR][ML][BACKPORT-2.0] fix glm for Gamma, clarify glm family supported

Repository: spark
Updated Branches:
  refs/heads/branch-2.0 6fe676c09 -> ec2fe925c


[SPARK-19133][SPARKR][ML][BACKPORT-2.0] fix glm for Gamma, clarify glm family supported

## What changes were proposed in this pull request?

Backport to branch-2.0 (cherry-picking the change from branch-2.1 did not work).

## How was this patch tested?

Added a unit test that fits a `glm` with the Gamma family and checks the printed summary.

Author: Felix Cheung <fe...@hotmail.com>

Closes #16543 from felixcheung/rgammabackport20.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec2fe925
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec2fe925
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec2fe925

Branch: refs/heads/branch-2.0
Commit: ec2fe925cd359ca5c132372d4b18ff791b70605a
Parents: 6fe676c
Author: Felix Cheung <fe...@hotmail.com>
Authored: Wed Jan 11 20:01:11 2017 -0800
Committer: Felix Cheung <fe...@apache.org>
Committed: Wed Jan 11 20:01:11 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib.R                        | 7 ++++++-
 R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec2fe925/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b33a16a..cd07f27 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -89,6 +89,8 @@ NULL
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param tol positive convergence tolerance of iterations.
 #' @param maxIter integer giving the maximal number of IRLS iterations.
 #' @param ... additional arguments passed to the method.
@@ -134,8 +136,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
 
             formula <- paste(deparse(formula), collapse = "")
 
+            # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
-                                "fit", formula, data@sdf, family$family, family$link,
+                                "fit", formula, data@sdf, tolower(family$family), family$link,
                                 tol, as.integer(maxIter))
             return(new("GeneralizedLinearRegressionModel", jobj = jobj))
           })
@@ -150,6 +153,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param epsilon positive convergence tolerance of iterations.
 #' @param maxit integer giving the maximal number of IRLS iterations.
 #' @return \code{glm} returns a fitted generalized linear model.

http://git-wip-us.apache.org/repos/asf/spark/blob/ec2fe925/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 753da81..e0d2e53 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -69,6 +69,14 @@ test_that("spark.glm and predict", {
   data = iris, family = poisson(link = identity)), iris))
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 
+  # Gamma family
+  x <- runif(100, -1, 1)
+  y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+  df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+  model <- glm(y ~ x, family = Gamma, df)
+  out <- capture.output(print(summary(model)))
+  expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
   # Test stats::predict is working
   x <- rnorm(15)
   y <- x + rnorm(15)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org