You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this copy.)
Posted to commits@spark.apache.org by fe...@apache.org on 2017/01/11 05:22:22 UTC
spark git commit: [SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported
Repository: spark
Updated Branches:
refs/heads/branch-2.1 230607d62 -> 1022049c7
[SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported
## What changes were proposed in this pull request?
Backport the SPARK-19133 fix (glm with the Gamma family) to branches 2.1, 2.0, and 1.6.
## How was this patch tested?
Unit tests (adds a Gamma-family case to R/pkg/inst/tests/testthat/test_mllib.R).
Author: Felix Cheung <fe...@hotmail.com>
Closes #16532 from felixcheung/rgammabackport.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1022049c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1022049c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1022049c
Branch: refs/heads/branch-2.1
Commit: 1022049c78e55914c54dff6d5206ad56dba7eef4
Parents: 230607d
Author: Felix Cheung <fe...@hotmail.com>
Authored: Tue Jan 10 21:22:16 2017 -0800
Committer: Felix Cheung <fe...@apache.org>
Committed: Tue Jan 10 21:22:16 2017 -0800
----------------------------------------------------------------------
R/pkg/R/mllib.R | 7 ++++++-
R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1022049c/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index d736bbb..1a254ad 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -184,6 +184,8 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' Currently these families are supported: \code{binomial}, \code{gaussian},
+#' \code{Gamma}, and \code{poisson}.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
weightCol <- ""
}
+ # For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
- "fit", formula, data@sdf, family$family, family$link,
+ "fit", formula, data@sdf, tolower(family$family), family$link,
tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
@@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' Currently these families are supported: \code{binomial}, \code{gaussian},
+#' \code{Gamma}, and \code{poisson}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
http://git-wip-us.apache.org/repos/asf/spark/blob/1022049c/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 40c0446..1f2fae9 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -74,6 +74,14 @@ test_that("spark.glm and predict", {
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+ # Gamma family
+ x <- runif(100, -1, 1)
+ y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+ df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+ model <- glm(y ~ x, family = Gamma, df)
+ out <- capture.output(print(summary(model)))
+ expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
# Test stats::predict is working
x <- rnorm(15)
y <- x + rnorm(15)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org