You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/12 14:14:15 UTC
spark git commit: [SPARK-14077][ML][FOLLOW-UP] Minor refactor and
cleanup for NaiveBayes
Repository: spark
Updated Branches:
refs/heads/master bc41d997e -> 22cb3a060
[SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes
## What changes were proposed in this pull request?
* Refactor out ```trainWithLabelCheck``` and make ```mllib.NaiveBayes``` call into it.
* Read ```modelType``` into a local value before building the feature-validation function, so the outer object is not captured.
* Move ```requireNonnegativeValues``` and ```requireZeroOneBernoulliValues``` to the ```NaiveBayes``` companion object.
## How was this patch tested?
Existing tests.
Author: Yanbo Liang <yb...@gmail.com>
Closes #15826 from yanboliang/spark-14077-2.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/22cb3a06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/22cb3a06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/22cb3a06
Branch: refs/heads/master
Commit: 22cb3a060a440205281b71686637679645454ca6
Parents: bc41d99
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Nov 12 06:13:22 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sat Nov 12 06:13:22 2016 -0800
----------------------------------------------------------------------
.../spark/ml/classification/NaiveBayes.scala | 72 ++++++++++----------
.../spark/mllib/classification/NaiveBayes.scala | 6 +-
2 files changed, 39 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/22cb3a06/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index b03a07a..f1a7676 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") (
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams with DefaultParamsWritable {
- import NaiveBayes.{Bernoulli, Multinomial}
+ import NaiveBayes._
@Since("1.5.0")
def this() = this(Identifiable.randomUID("nb"))
@@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") (
@Since("2.1.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+ trainWithLabelCheck(dataset, positiveLabel = true)
+ }
+
/**
* ml assumes input labels in range [0, numClasses). But this implementation
* is also called by mllib NaiveBayes which allows other kinds of input labels
- * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
- * It should be removed when we remove mllib NaiveBayes.
+ * such as {-1, +1}. `positiveLabel` is used to determine whether the label
+ * should be checked and it should be removed when we remove mllib NaiveBayes.
*/
- private[spark] var isML: Boolean = true
-
- private[spark] def setIsML(isML: Boolean): this.type = {
- this.isML = isML
- this
- }
-
- override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
- if (isML) {
+ private[spark] def trainWithLabelCheck(
+ dataset: Dataset[_],
+ positiveLabel: Boolean): NaiveBayesModel = {
+ if (positiveLabel) {
val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") (
}
}
- val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
- val values = v match {
- case sv: SparseVector => sv.values
- case dv: DenseVector => dv.values
- }
-
- require(values.forall(_ >= 0.0),
- s"Naive Bayes requires nonnegative feature values but found $v.")
- }
-
- val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
- val values = v match {
- case sv: SparseVector => sv.values
- case dv: DenseVector => dv.values
- }
-
- require(values.forall(v => v == 0.0 || v == 1.0),
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
- }
-
+ val modelTypeValue = $(modelType)
val requireValues: Vector => Unit = {
- $(modelType) match {
+ modelTypeValue match {
case Multinomial =>
requireNonnegativeValues
case Bernoulli =>
@@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") (
@Since("1.6.0")
object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/** String name for multinomial model type. */
- private[spark] val Multinomial: String = "multinomial"
+ private[classification] val Multinomial: String = "multinomial"
/** String name for Bernoulli model type. */
- private[spark] val Bernoulli: String = "bernoulli"
+ private[classification] val Bernoulli: String = "bernoulli"
/* Set of modelTypes that NaiveBayes supports */
- private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+
+ private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
+ val values = v match {
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
+ }
+
+ require(values.forall(_ >= 0.0),
+ s"Naive Bayes requires nonnegative feature values but found $v.")
+ }
+
+ private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
+ val values = v match {
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
+ }
+
+ require(values.forall(v => v == 0.0 || v == 1.0),
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
+ }
@Since("1.6.0")
override def load(path: String): NaiveBayes = super.load(path)
http://git-wip-us.apache.org/repos/asf/spark/blob/22cb3a06/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 33561be..767d056 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,12 +364,12 @@ class NaiveBayes private (
val nb = new NewNaiveBayes()
.setModelType(modelType)
.setSmoothing(lambda)
- .setIsML(false)
val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
.toDF("label", "features")
- val newModel = nb.fit(dataset)
+ // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false.
+ val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false)
val pi = newModel.pi.toArray
val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0)
@@ -378,7 +378,7 @@ class NaiveBayes private (
theta(i)(j) = v
}
- require(newModel.oldLabels != null,
+ assert(newModel.oldLabels != null,
"The underlying ML NaiveBayes training does not produce labels.")
new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org