You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by fe...@apache.org on 2017/08/01 03:37:10 UTC
spark git commit: [SPARK-21381][SPARKR] SparkR: pass on
setHandleInvalid for classification algorithms
Repository: spark
Updated Branches:
refs/heads/master 6b186c9d6 -> 9570e81aa
[SPARK-21381][SPARKR] SparkR: pass on setHandleInvalid for classification algorithms
## What changes were proposed in this pull request?
SPARK-20307 added the handleInvalid option to RFormula for tree-based classification algorithms. We should also add this parameter for the other classification algorithms in SparkR.
This is a followup PR for SPARK-20307.
## How was this patch tested?
New unit tests are added.
Author: wangmiao1981 <wm...@hotmail.com>
Closes #18605 from wangmiao1981/class.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9570e81a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9570e81a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9570e81a
Branch: refs/heads/master
Commit: 9570e81aa949cddb30a0e94c92093cd16e34326a
Parents: 6b186c9
Author: wangmiao1981 <wm...@hotmail.com>
Authored: Mon Jul 31 20:37:06 2017 -0700
Committer: Felix Cheung <fe...@apache.org>
Committed: Mon Jul 31 20:37:06 2017 -0700
----------------------------------------------------------------------
R/pkg/R/mllib_classification.R | 49 ++++++++++++++---
R/pkg/R/mllib_tree.R | 33 ++++++++---
.../tests/fulltests/test_mllib_classification.R | 58 ++++++++++++++++++++
R/pkg/tests/fulltests/test_mllib_tree.R | 30 ++++++++++
.../r/DecisionTreeClassificationWrapper.scala | 4 +-
.../spark/ml/r/GBTClassificationWrapper.scala | 4 +-
.../apache/spark/ml/r/LinearSVCWrapper.scala | 4 +-
.../spark/ml/r/LogisticRegressionWrapper.scala | 4 +-
.../MultilayerPerceptronClassifierWrapper.scala | 6 +-
.../apache/spark/ml/r/NaiveBayesWrapper.scala | 7 ++-
10 files changed, 175 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/R/mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 82d2428..15af829 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
#' This is an expert parameter. Default value should be good for most cases.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @return \code{spark.svmLinear} returns a fitted linear SVM model.
#' @rdname spark.svmLinear
@@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @note spark.svmLinear since 2.2.0
setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
- threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+ threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
+ handleInvalid = c("error", "keep", "skip")) {
formula <- paste(deparse(formula), collapse = "")
if (!is.null(weightCol) && weightCol == "") {
@@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
weightCol <- as.character(weightCol)
}
+ handleInvalid <- match.arg(handleInvalid)
+
jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
as.numeric(tol), as.logical(standardization), as.numeric(threshold),
- weightCol, as.integer(aggregationDepth))
+ weightCol, as.integer(aggregationDepth), handleInvalid)
new("LinearSVCModel", jobj = jobj)
})
@@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) {
#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization.
#' The bound vector size must be equal to 1 for binomial regression, or the number
#' of classes for multinomial regression.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
@@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
tol = 1E-6, family = "auto", standardization = TRUE,
thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL,
- lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) {
+ lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL,
+ handleInvalid = c("error", "keep", "skip")) {
formula <- paste(deparse(formula), collapse = "")
row <- 0
col <- 0
@@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients))
}
+ handleInvalid <- match.arg(handleInvalid)
+
jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
data@sdf, formula, as.numeric(regParam),
as.numeric(elasticNetParam), as.integer(maxIter),
@@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
weightCol, as.integer(aggregationDepth),
as.integer(row), as.integer(col),
lowerBoundsOnCoefficients, upperBoundsOnCoefficients,
- lowerBoundsOnIntercepts, upperBoundsOnIntercepts)
+ lowerBoundsOnIntercepts, upperBoundsOnIntercepts,
+ handleInvalid)
new("LogisticRegressionModel", jobj = jobj)
})
@@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
#' @param stepSize stepSize parameter.
#' @param seed seed parameter for weights initialization.
#' @param initialWeights initialWeights parameter for weights initialization, it should be a
-#' numeric vector.
+#' numeric vector.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
#' @rdname spark.mlp
@@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
#' @note spark.mlp since 2.1.0
setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
- tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+ tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL,
+ handleInvalid = c("error", "keep", "skip")) {
formula <- paste(deparse(formula), collapse = "")
if (is.null(layers)) {
stop ("layers must be a integer vector with length > 1.")
@@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
if (!is.null(initialWeights)) {
initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
}
+ handleInvalid <- match.arg(handleInvalid)
jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
"fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
as.character(solver), as.integer(maxIter), as.numeric(tol),
- as.numeric(stepSize), seed, initialWeights)
+ as.numeric(stepSize), seed, initialWeights, handleInvalid)
new("MultilayerPerceptronClassificationModel", jobj = jobj)
})
@@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param smoothing smoothing parameter.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
#' @rdname spark.naiveBayes
@@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
#' }
#' @note spark.naiveBayes since 2.0.0
setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, smoothing = 1.0) {
+ function(data, formula, smoothing = 1.0,
+ handleInvalid = c("error", "keep", "skip")) {
formula <- paste(deparse(formula), collapse = "")
+ handleInvalid <- match.arg(handleInvalid)
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
- formula, data@sdf, smoothing)
+ formula, data@sdf, smoothing, handleInvalid)
new("NaiveBayesModel", jobj = jobj)
})
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 75b1a74..33c4653 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) {
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type in classification model.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.gbt,SparkDataFrame,formula-method
#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
@@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
- checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
+ handleInvalid = c("error", "keep", "skip")) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
new("GBTRegressionModel", jobj = jobj)
},
classification = {
+ handleInvalid <- match.arg(handleInvalid)
if (is.null(lossType)) lossType <- "logistic"
lossType <- match.arg(lossType, "logistic")
jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
@@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
as.numeric(stepSize), as.integer(minInstancesPerNode),
as.numeric(minInfoGain), as.integer(checkpointInterval),
lossType, seed, as.numeric(subsamplingRate),
- as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+ handleInvalid)
new("GBTClassificationModel", jobj = jobj)
}
)
@@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
-#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
-#' Supported options: "skip" (filter out rows with invalid data),
-#' "error" (throw an error), "keep" (put invalid data in a special additional
-#' bucket, at index numLabels). Default is "error".
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type in classification model.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#' column of string type in classification model.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
@@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
- maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ maxMemoryInMB = 256, cacheNodeIds = FALSE,
+ handleInvalid = c("error", "keep", "skip")) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
new("DecisionTreeRegressionModel", jobj = jobj)
},
classification = {
+ handleInvalid <- match.arg(handleInvalid)
if (is.null(impurity)) impurity <- "gini"
impurity <- match.arg(impurity, c("gini", "entropy"))
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
@@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
as.integer(maxBins), impurity,
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
as.integer(checkpointInterval), seed,
- as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+ handleInvalid)
new("DecisionTreeClassificationModel", jobj = jobj)
}
)
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/tests/fulltests/test_mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R
index 3d75f4c..a4d0397 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -70,6 +70,20 @@ test_that("spark.svmLinear", {
prediction <- collect(select(predict(model, df), "prediction"))
expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1)
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "list")
+
})
test_that("spark.logit", {
@@ -263,6 +277,21 @@ test_that("spark.logit", {
virginicaCoefs <- summary$coefficients[, "virginica"]
expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
+
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.logit(traindf, clicked ~ ., regParam = 0.5)
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "character")
+
})
test_that("spark.mlp", {
@@ -344,6 +373,21 @@ test_that("spark.mlp", {
expect_equal(summary$numOfOutputs, 3)
expect_equal(summary$layers, c(4, 3))
expect_equal(length(summary$weights), 15)
+
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "list")
+
})
test_that("spark.naiveBayes", {
@@ -427,6 +471,20 @@ test_that("spark.naiveBayes", {
expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
expect_equal(sum(s$apriori), 1)
expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0)
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "character")
})
sparkR.session.stop()
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index e31a65f..799f944 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -109,6 +109,20 @@ test_that("spark.gbt", {
model <- spark.gbt(data, label ~ features, "classification")
expect_equal(summary(model)$numFeatures, 692)
}
+
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.gbt(traindf, clicked ~ ., type = "classification")
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "character")
})
test_that("spark.randomForest", {
@@ -328,6 +342,22 @@ test_that("spark.decisionTree", {
model <- spark.decisionTree(data, label ~ features, "classification")
expect_equal(summary(model)$numFeatures, 4)
}
+
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+ maxDepth = 5, maxBins = 16)
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+ maxDepth = 5, maxBins = 16, handleInvalid = "keep")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "character")
})
sparkR.session.stop()
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
index 7f59825..a90cae5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
checkpointInterval: Int,
seed: String,
maxMemoryInMB: Int,
- cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+ cacheNodeIds: Boolean,
+ handleInvalid: String): DecisionTreeClassifierWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index c07eadb..ecaeec5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
seed: String,
subsamplingRate: Double,
maxMemoryInMB: Int,
- cacheNodeIds: Boolean): GBTClassifierWrapper = {
+ cacheNodeIds: Boolean,
+ handleInvalid: String): GBTClassifierWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
index 0dd1f11..7a22a71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper
standardization: Boolean,
threshold: Double,
weightCol: String,
- aggregationDepth: Int
+ aggregationDepth: Int,
+ handleInvalid: String
): LinearSVCWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index b96481a..18acf7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper
lowerBoundsOnCoefficients: Array[Double],
upperBoundsOnCoefficients: Array[Double],
lowerBoundsOnIntercepts: Array[Double],
- upperBoundsOnIntercepts: Array[Double]
+ upperBoundsOnIntercepts: Array[Double],
+ handleInvalid: String
): LogisticRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
index 48c8774..62f6421 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"
- def fit(
+ def fit( // scalastyle:ignore
data: DataFrame,
formula: String,
blockSize: Int,
@@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper
tol: Double,
stepSize: Double,
seed: String,
- initialWeights: Array[Double]
+ initialWeights: Array[Double],
+ handleInvalid: String
): MultilayerPerceptronClassifierWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
http://git-wip-us.apache.org/repos/asf/spark/blob/9570e81a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 0afea4b..fbf9f46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"
- def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
+ def fit(
+ formula: String,
+ data: DataFrame,
+ smoothing: Double,
+ handleInvalid: String): NaiveBayesWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
.setForceIndexLabel(true)
+ .setHandleInvalid(handleInvalid)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org