You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/02/16 07:59:45 UTC

[GitHub] [spark] huaxingao commented on a change in pull request #27571: [SPARK-30819][SPARKR][ML] Add FMRegressor wrapper to SparkR

huaxingao commented on a change in pull request #27571: [SPARK-30819][SPARKR][ML]  Add FMRegressor wrapper to SparkR
URL: https://github.com/apache/spark/pull/27571#discussion_r379882358
 
 

 ##########
 File path: R/pkg/R/mllib_regression.R
 ##########
 @@ -540,3 +546,150 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c
           function(object, path, overwrite = FALSE) {
             write_internal(object, path, overwrite)
           })
+
+
+#' Factorization Machines Regression Model
+#'
+#' \code{spark.fmRegressor} fits a Factorization Machines regression model against a
+#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted model,
+#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#'
+#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param factorSize dimensionality of the factors.
+#' @param fitLinear whether to fit the linear term.
+#' @param regParam the regularization parameter.
+#' @param miniBatchFraction the mini-batch fraction parameter.
+#' @param initStd the standard deviation of initial coefficients.
+#' @param maxIter maximum iteration number.
+#' @param stepSize step size parameter.
+#' @param tol convergence tolerance of iterations.
+#' @param solver solver parameter, supported options: "adamW" (the default) or "gd"
+#'               (mini-batch gradient descent).
+#' @param seed seed parameter for weights initialization. Defaults to \code{NULL}; a non-NULL
+#'             value is coerced to an integer.
+#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
+#'                               decide the base level of a string feature as the last category
+#'                               after ordering is dropped when encoding strings. Supported options
+#'                               are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
+#'                               "alphabetAsc". The default value is "frequencyDesc". When the
+#'                               ordering is set to "alphabetDesc", this drops the same category
+#'                               as R when encoding strings.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.fmRegressor} returns a fitted Factorization Machines Regression Model.
+#'
+#' @rdname spark.fmRegressor
+#' @aliases spark.fmRegressor,SparkDataFrame,formula-method
+#' @name spark.fmRegressor
+#' @seealso \link{read.ml}
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
+#'
+#' # fit Factorization Machines Regression Model
+#' model <- spark.fmRegressor(
+#'            df, label ~ features,
+#'            regParam = 0.01, maxIter = 10, fitLinear = TRUE
+#'          )
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.fmRegressor since 3.1.0
+setMethod("spark.fmRegressor", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 0.0,
+                   miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, stepSize = 1.0,
+                   tol = 1e-6, solver = c("adamW", "gd"), seed = NULL,
+                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
+                                              "alphabetDesc", "alphabetAsc")) {
+            # deparse() may split a long formula over several strings; join them so the
+            # JVM wrapper receives the formula as one string.
+            formula <- paste(deparse(formula), collapse = "")
+
+            # The seed is handed to the JVM as a string; NULL is forwarded unchanged.
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+
+            solver <- match.arg(solver)
+            stringIndexerOrderType <- match.arg(stringIndexerOrderType)
+
+            # TODO: Can the fitLinear option be expressed through the formula instead?
+            jobj <- callJStatic("org.apache.spark.ml.r.FMRegressorWrapper",
+                                "fit",
+                                data@sdf,
+                                formula,
+                                as.integer(factorSize),
+                                as.logical(fitLinear),
+                                as.numeric(regParam),
+                                as.numeric(miniBatchFraction),
+                                as.numeric(initStd),
+                                as.integer(maxIter),
+                                as.numeric(stepSize),
+                                as.numeric(tol),
+                                solver,
+                                seed,
+                                stringIndexerOrderType)
+            new("FMRegressionModel", jobj = jobj)
+          })
+
+
+#  Returns the summary of an FM regression model produced by \code{spark.fmRegressor}
+
+#' @param object an FM regression model fitted by \code{spark.fmRegressor}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list includes \code{coefficients} (a one-column matrix of estimates whose
+#'         row names are the feature names), \code{factors} (the factor matrix, with
+#'         \code{factorSize} columns), \code{numFeatures} and \code{factorSize}.
+#'
+#' @rdname spark.fmRegressor
+#' @note summary(FMRegressionModel) since 3.1.0
+setMethod("summary", signature(object = "FMRegressionModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- "Estimate"
+            rownames(coefficients) <- unlist(features)
+            numFeatures <- callJMethod(jobj, "numFeatures")
+            # The factors arrive from the JVM as a flat vector and are reshaped below.
+            rawFactors <- unlist(callJMethod(jobj, "rFactors"))
+            factorSize <- callJMethod(jobj, "factorSize")
+
+            list(
+              coefficients = coefficients,
+              # matrix() fills column-wise by default.
+              # NOTE(review): assumes rFactors is column-major; confirm against the wrapper.
+              factors = matrix(rawFactors, ncol = factorSize),
+              numFeatures = numFeatures,
+              factorSize = factorSize
+            )
+          })
+
+
 
 Review comment:
   ditto

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org