You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/03/22 22:16:56 UTC

spark git commit: [SPARK-13449] Naive Bayes wrapper in SparkR

Repository: spark
Updated Branches:
  refs/heads/master b2b1ad7d4 -> d6dc12ef0


[SPARK-13449] Naive Bayes wrapper in SparkR

## What changes were proposed in this pull request?

This PR continues the work in #11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli.

I moved the implementation out of SparkRWrappers into NaiveBayesWrapper to make it easier to read. Argument names, default values, and the summary output now match e1071's naiveBayes.

I removed the preprocessing step that omitted NA values, because we don't know which columns it should apply to.

## How was this patch tested?

Test against output from R package e1071's naiveBayes.

cc: yanboliang yinxusen

Closes #11486

Author: Xusen Yin <yi...@gmail.com>
Author: Xiangrui Meng <me...@databricks.com>

Closes #11890 from mengxr/SPARK-13449.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d6dc12ef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d6dc12ef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d6dc12ef

Branch: refs/heads/master
Commit: d6dc12ef0146ae409834c78737c116050961f350
Parents: b2b1ad7
Author: Xusen Yin <yi...@gmail.com>
Authored: Tue Mar 22 14:16:51 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Mar 22 14:16:51 2016 -0700

----------------------------------------------------------------------
 R/pkg/DESCRIPTION                               |  3 +-
 R/pkg/NAMESPACE                                 |  3 +-
 R/pkg/R/generics.R                              |  4 +
 R/pkg/R/mllib.R                                 | 91 ++++++++++++++++++--
 R/pkg/inst/tests/testthat/test_mllib.R          | 59 +++++++++++++
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   | 75 ++++++++++++++++
 6 files changed, 228 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/R/pkg/DESCRIPTION
----------------------------------------------------------------------
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 0cd0d75..e26f9a7 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -11,7 +11,8 @@ Depends:
     R (>= 3.0),
     methods,
 Suggests:
-    testthat
+    testthat,
+    e1071
 Description: R frontend for Spark
 License: Apache License (== 2.0)
 Collate:

http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 636d39e..5d8a4b1 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -15,7 +15,8 @@ exportMethods("glm",
               "predict",
               "summary",
               "kmeans",
-              "fitted")
+              "fitted",
+              "naiveBayes")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 6ad71fc..46b115f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
 #' @rdname fitted
 #' @export
 setGeneric("fitted")
+
+#' @rdname naiveBayes
+#' @export
+setGeneric("naiveBayes", function(formula, data, ...) standardGeneric("naiveBayes"))

http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 5c0d3dc..2555019 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -22,6 +22,11 @@
 #' @export
 setClass("PipelineModel", representation(model = "jobj"))
 
+#' @title S4 class wrapping a naive Bayes model fitted by naiveBayes()
+#' @param jobj reference to the JVM-side org.apache.spark.ml.r.NaiveBayesWrapper object
+#' @export
+setClass("NaiveBayesModel", slots = c(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' @rdname glm
 #' @export
 #' @examples
-#'\dontrun{
+#' \dontrun{
 #' sc <- sparkR.init()
 #' sqlContext <- sparkRSQL.init(sc)
 #' data(iris)
@@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
 #' @rdname predict
 #' @export
 #' @examples
-#'\dontrun{
+#' \dontrun{
 #' model <- glm(y ~ x, trainingData)
 #' predicted <- predict(model, testData)
 #' showDF(predicted)
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
             return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
           })
 
+#' Make predictions from a naive Bayes model
+#'
+#' Applies a model produced by naiveBayes() to new data, similarly to R package e1071's predict.
+#'
+#' @param object A naive Bayes model returned by naiveBayes()
+#' @param newData DataFrame to make predictions on
+#' @return DataFrame with the model's output added; predicted labels are in column "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- naiveBayes(y ~ x, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#'}
+setMethod("predict", signature(object = "NaiveBayesModel"),
+          function(object, newData) {
+            dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+          })
+
 #' Get the summary of a model
 #'
 #' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
 #' @rdname summary
 #' @export
 #' @examples
-#'\dontrun{
+#' \dontrun{
 #' model <- glm(y ~ x, trainingData)
 #' summary(model)
 #'}
@@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
             }
           })
 
+#' Get the summary of a naive Bayes model
+#'
+#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
+#'
+#' @param object A fitted naive Bayes model
+#' @return a list with 'apriori', a one-row matrix of label prior probabilities, and 'tables',
+#'         a numeric labels-by-features matrix of conditional probabilities given the label
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- naiveBayes(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "NaiveBayesModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            # Arrays returned from the JVM arrive as R lists; unlist() them into plain vectors.
+            features <- unlist(callJMethod(jobj, "features"))
+            labels <- unlist(callJMethod(jobj, "labels"))
+            apriori <- t(as.matrix(unlist(callJMethod(jobj, "apriori"))))
+            colnames(apriori) <- labels
+            # 'tables' is labels x features flattened column-major, the order matrix() fills.
+            probs <- as.numeric(unlist(callJMethod(jobj, "tables")))
+            tables <- matrix(probs, nrow = length(labels))
+            dimnames(tables) <- list(labels, features)
+            return(list(apriori = apriori, tables = tables))
+          })
+
 #' Fit a k-means model
 #'
 #' Fit a k-means model, similarly to R's kmeans().
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
 #' @rdname kmeans
 #' @export
 #' @examples
-#'\dontrun{
+#' \dontrun{
 #' model <- kmeans(x, centers = 2, algorithm="random")
 #'}
 setMethod("kmeans", signature(x = "DataFrame"),
@@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
 #' @rdname fitted
 #' @export
 #' @examples
-#'\dontrun{
+#' \dontrun{
 #' model <- kmeans(trainingData, 2)
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
@@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
               stop(paste("Unsupported model", modelName, sep = " "))
             }
           })
+
+#' Fit a Bernoulli naive Bayes model
+#'
+#' Fits a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes(), but only
+#' categorical features are supported; data must hold observations, not a contingency table.
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param data DataFrame for training
+#' @param laplace Smoothing parameter; the default 0 disables Laplace smoothing, as in e1071
+#' @return a fitted naive Bayes model
+#' @rdname naiveBayes
+#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#'}
+setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
+          function(formula, data, laplace = 0, ...) {
+            formula <- paste(deparse(formula), collapse = "")
+            # fit() declares laplace: Double; as.numeric() makes e.g. 1L arrive as a double.
+            jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
+                                formula, data@sdf, as.numeric(laplace))
+            return(new("NaiveBayesModel", jobj = jobj))
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index e120462..44b4836 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -141,3 +141,62 @@ test_that("kmeans", {
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
 })
+
+test_that("naiveBayes", {
+  # Reference values come from R package e1071. Instance weights are not supported yet, so
+  # the Freq column is dropped and every remaining row counts as a single observation:
+  #
+  #   library(e1071)
+  #   titanic <- as.data.frame(Titanic)
+  #   t1 <- titanic[titanic$Freq > 0, -5]
+  #   m <- naiveBayes(Survived ~ ., data = t1)
+  #   m
+  #   predict(m, t1)
+  #
+  # -- output of 'm'
+  #
+  # A-priori probabilities:
+  # Y
+  #        No       Yes
+  # 0.4166667 0.5833333
+  #
+  # Conditional probabilities:
+  #      Class
+  # Y           1st       2nd       3rd      Crew
+  #   No  0.2000000 0.2000000 0.4000000 0.2000000
+  #   Yes 0.2857143 0.2857143 0.2857143 0.1428571
+  #
+  #      Sex
+  # Y     Male Female
+  #   No   0.5    0.5
+  #   Yes  0.5    0.5
+  #
+  #      Age
+  # Y         Child     Adult
+  #   No  0.2000000 0.8000000
+  #   Yes 0.4285714 0.5714286
+  #
+  # -- output of 'predict(m, t1)'
+  #
+  # Yes Yes Yes Yes No  No  Yes Yes No  No  Yes Yes Yes Yes Yes Yes Yes Yes No  No  Yes Yes No  No
+  #
+
+  titanic <- as.data.frame(Titanic)
+  t1 <- titanic[titanic$Freq > 0, -5]
+  titanicDF <- suppressWarnings(createDataFrame(sqlContext, t1))
+  model <- naiveBayes(Survived ~ ., data = titanicDF)
+  modelSummary <- summary(model)
+  expect_equal(as.double(modelSummary$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
+  expect_equal(sum(modelSummary$apriori), 1)
+  expect_equal(as.double(modelSummary$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+  predictions <- collect(select(predict(model, titanicDF), "prediction"))
+  expected <- c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", "Yes", "Yes",
+                "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No")
+  expect_equal(predictions$prediction, expected)
+
+  # With e1071 installed, its naiveBayes() and predict() should keep working alongside SparkR.
+  if (requireNamespace("e1071", quietly = TRUE)) {
+    expect_that(localModel <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
+    expect_equal(as.character(predict(localModel, t1[1, ])), "Yes")
+  }
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/d6dc12ef/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
new file mode 100644
index 0000000..07383d3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * SparkR-facing wrapper around a fitted Bernoulli naive Bayes pipeline
+ * (RFormula -> NaiveBayes -> IndexToString), built by [[NaiveBayesWrapper.fit]].
+ *
+ * @param pipeline the fitted pipeline; one of its stages must be a NaiveBayesModel
+ * @param labels label values, where position i holds the label with index i (the row order of
+ *               `apriori` and `tables`)
+ * @param features names of the assembled feature vector's entries (the column order of `tables`)
+ */
+private[r] class NaiveBayesWrapper private (
+    pipeline: PipelineModel,
+    val labels: Array[String],
+    val features: Array[String]) {
+
+  import NaiveBayesWrapper._
+
+  // Find the model by type instead of depending on its index among the pipeline stages.
+  private val naiveBayesModel: NaiveBayesModel = pipeline.stages.collectFirst {
+    case model: NaiveBayesModel => model
+  }.getOrElse(throw new IllegalStateException("Pipeline does not contain a NaiveBayesModel."))
+
+  /** Class prior probabilities, exponentiated from the model's log-scale `pi`. */
+  lazy val apriori: Array[Double] = naiveBayesModel.pi.toArray.map(math.exp)
+
+  /**
+   * Conditional probabilities, exponentiated from the model's log-scale `theta`
+   * (labels x features) and flattened in column-major order by `Matrix.toArray`.
+   */
+  lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
+
+  /**
+   * Runs the fitted pipeline on `dataset`. Predicted label strings are in the "prediction"
+   * column; the intermediate predicted-index column is dropped.
+   */
+  def transform(dataset: DataFrame): DataFrame = {
+    pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
+  }
+}
+
+private[r] object NaiveBayesWrapper {
+
+  /** Intermediate column holding the predicted label index; dropped by `transform`. */
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  /** Output column holding the predicted label string. */
+  val PREDICTED_LABEL_COL = "prediction"
+
+  /**
+   * Fits an RFormula -> NaiveBayes (bernoulli) -> IndexToString pipeline on `data`.
+   *
+   * @param formula R model formula; RFormula must produce a nominal (categorical) label from it
+   * @param data training data
+   * @param laplace smoothing parameter, passed to `NaiveBayes.setSmoothing`
+   * @return the fitted wrapper
+   * @throws IllegalArgumentException if the label is not nominal or the assembled features
+   *                                  carry no attribute names
+   */
+  def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = {
+    val rFormula = new RFormula()
+      .setFormula(formula)
+      .fit(data)
+    // get labels and feature names from output schema
+    val schema = rFormula.transform(data).schema
+    val labels: Array[String] = Attribute.fromStructField(schema(rFormula.getLabelCol)) match {
+      case nominal: NominalAttribute =>
+        nominal.values.getOrElse(throw new IllegalArgumentException(
+          s"Label column ${rFormula.getLabelCol} has no nominal values."))
+      case other =>
+        throw new IllegalArgumentException("naiveBayes needs a categorical (string) label; " +
+          s"got attribute type ${other.attrType.name}.")
+    }
+    val features: Array[String] = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+      .attributes
+      .getOrElse(throw new IllegalArgumentException(
+        s"Features column ${rFormula.getFeaturesCol} has no attribute metadata."))
+      .map(_.name.getOrElse(
+        throw new IllegalArgumentException("Every assembled feature must have a name.")))
+    // assemble and fit the pipeline, wiring RFormula's output columns in explicitly
+    val naiveBayes = new NaiveBayes()
+      .setSmoothing(laplace)
+      .setModelType("bernoulli")
+      .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormula, naiveBayes, idxToStr))
+      .fit(data)
+    new NaiveBayesWrapper(pipeline, labels, features)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org