You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/10/11 19:41:41 UTC

spark git commit: [SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type

Repository: spark
Updated Branches:
  refs/heads/master 07508bd01 -> 23405f324


[SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type

## What changes were proposed in this pull request?
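Fix the SparkR `spark.naiveBayes` error that occurs when the dataset's response variable is of numeric type. The fix calls `setForceIndexLabel(true)` on the `RFormula` used by `NaiveBayesWrapper`, so the label column is always indexed, even when it is numeric.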
See details and how to reproduce this bug at [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153).

## How was this patch tested?
Added a unit test in `test_mllib.R` that fits `spark.naiveBayes` with a numeric response variable.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15431 from yanboliang/spark-15153-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23405f32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23405f32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23405f32

Branch: refs/heads/master
Commit: 23405f324a8089f86ebcbede9bb32944137508e8
Parents: 07508bd
Author: Yanbo Liang <yb...@gmail.com>
Authored: Tue Oct 11 12:41:35 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Oct 11 12:41:35 2016 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R                    | 10 ++++++++++
 .../scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala   |  1 +
 2 files changed, 11 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index a1eaaf2..c993157 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -481,6 +481,16 @@ test_that("spark.naiveBayes", {
     expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
     expect_equal(as.character(predict(m, t1[1, ])), "Yes")
   }
+
+  # Test numeric response variable
+  t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1)
+  t2 <- t1[-4]
+  df <- suppressWarnings(createDataFrame(t2))
+  m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0)
+  s <- summary(m)
+  expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
+  expect_equal(sum(s$apriori), 1)
+  expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
 })
 
 test_that("spark.survreg", {

http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index d1a39fe..4fdab2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
   def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
+      .setForceIndexLabel(true)
     RWrapperUtils.checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org