You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/10/11 19:41:41 UTC
spark git commit: [SPARK-15153][ML][SPARKR] Fix SparkR
spark.naiveBayes error when label is numeric type
Repository: spark
Updated Branches:
refs/heads/master 07508bd01 -> 23405f324
[SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type
## What changes were proposed in this pull request?
Fix the SparkR `spark.naiveBayes` error that occurs when the dataset's response variable is of numeric type.
See [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153) for details and steps to reproduce this bug.
## How was this patch tested?
Added a unit test covering a numeric response variable.
Author: Yanbo Liang <yb...@gmail.com>
Closes #15431 from yanboliang/spark-15153-2.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23405f32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23405f32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23405f32
Branch: refs/heads/master
Commit: 23405f324a8089f86ebcbede9bb32944137508e8
Parents: 07508bd
Author: Yanbo Liang <yb...@gmail.com>
Authored: Tue Oct 11 12:41:35 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Oct 11 12:41:35 2016 -0700
----------------------------------------------------------------------
R/pkg/inst/tests/testthat/test_mllib.R | 10 ++++++++++
.../scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 1 +
2 files changed, 11 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index a1eaaf2..c993157 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -481,6 +481,16 @@ test_that("spark.naiveBayes", {
expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
+
+ # Test numeric response variable
+ t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1)
+ t2 <- t1[-4]
+ df <- suppressWarnings(createDataFrame(t2))
+ m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0)
+ s <- summary(m)
+ expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
+ expect_equal(sum(s$apriori), 1)
+ expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
})
test_that("spark.survreg", {
http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index d1a39fe..4fdab2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
+ .setForceIndexLabel(true)
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org