You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by fe...@apache.org on 2017/01/11 05:22:22 UTC

spark git commit: [SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported

Repository: spark
Updated Branches:
  refs/heads/branch-2.1 230607d62 -> 1022049c7


[SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported

## What changes were proposed in this pull request?

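Backport the SPARK-19133 fix (glm for the Gamma family, plus clarified documentation of the supported glm families) to branches 2.1, 2.0 and 1.6.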

## How was this patch tested?

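Unit tests: a Gamma family case was added to the "spark.glm and predict" test in R/pkg/inst/tests/testthat/test_mllib.R.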

Author: Felix Cheung <fe...@hotmail.com>

Closes #16532 from felixcheung/rgammabackport.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1022049c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1022049c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1022049c

Branch: refs/heads/branch-2.1
Commit: 1022049c78e55914c54dff6d5206ad56dba7eef4
Parents: 230607d
Author: Felix Cheung <fe...@hotmail.com>
Authored: Tue Jan 10 21:22:16 2017 -0800
Committer: Felix Cheung <fe...@apache.org>
Committed: Tue Jan 10 21:22:16 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib.R                        | 7 ++++++-
 R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1022049c/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index d736bbb..1a254ad 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -184,6 +184,8 @@ predict_internal <- function(object, newData) {
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param tol positive convergence tolerance of iterations.
 #' @param maxIter integer giving the maximal number of IRLS iterations.
 #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
               weightCol <- ""
             }
 
+            # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
-                                "fit", formula, data@sdf, family$family, family$link,
+                                "fit", formula, data@sdf, tolower(family$family), family$link,
                                 tol, as.integer(maxIter), as.character(weightCol), regParam)
             new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
@@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
 #'               This can be a character string naming a family function, a family function or
 #'               the result of a call to a family function. Refer R family at
 #'               \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#'               Currently these families are supported: \code{binomial}, \code{gaussian},
+#'               \code{Gamma}, and \code{poisson}.
 #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
 #'                  weights as 1.0.
 #' @param epsilon positive convergence tolerance of iterations.

http://git-wip-us.apache.org/repos/asf/spark/blob/1022049c/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 40c0446..1f2fae9 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -74,6 +74,14 @@ test_that("spark.glm and predict", {
   data = iris, family = poisson(link = identity)), iris))
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 
+  # Gamma family
+  x <- runif(100, -1, 1)
+  y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+  df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+  model <- glm(y ~ x, family = Gamma, df)
+  out <- capture.output(print(summary(model)))
+  expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
   # Test stats::predict is working
   x <- rnorm(15)
   y <- x + rnorm(15)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org