You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/10/04 13:55:01 UTC
spark git commit: [SPARK-17744][ML] Parity check between the ml and
mllib test suites for NB
Repository: spark
Updated Branches:
refs/heads/master 7d5160883 -> c17f97183
[SPARK-17744][ML] Parity check between the ml and mllib test suites for NB
## What changes were proposed in this pull request?
1. Parity check and add missing test suites for ml's NB
2. Remove some unused imports
## How was this patch tested?
manual tests in spark-shell
Author: Zheng RuiFeng <ru...@foxmail.com>
Closes #15312 from zhengruifeng/nb_test_parity.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c17f9718
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c17f9718
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c17f9718
Branch: refs/heads/master
Commit: c17f971839816e68f8abe2c8eb4e4db47c57ab67
Parents: 7d51608
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Tue Oct 4 06:54:48 2016 -0700
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Tue Oct 4 06:54:48 2016 -0700
----------------------------------------------------------------------
.../apache/spark/ml/feature/LabeledPoint.scala | 2 +-
.../spark/ml/feature/QuantileDiscretizer.scala | 2 +-
.../org/apache/spark/ml/python/MLSerDe.scala | 5 --
.../spark/ml/regression/GBTRegressor.scala | 2 +-
.../spark/ml/regression/LinearRegression.scala | 1 -
.../ml/classification/NaiveBayesSuite.scala | 69 +++++++++++++++++++-
python/pyspark/ml/classification.py | 1 -
7 files changed, 70 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index 6cefa70..7d8e4ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.Vector
/**
* :: Experimental ::
*
- * Class that represents the features and labels of a data point.
+ * Class that represents the features and label of a data point.
*
* @param label Label for this data point.
* @param features List of features for this data point.
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 1e59d71..05e034d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.StructType
/**
* Params for [[QuantileDiscretizer]].
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
index 4b805e1..da62f85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
@@ -19,17 +19,12 @@ package org.apache.spark.ml.python
import java.io.OutputStream
import java.nio.{ByteBuffer, ByteOrder}
-import java.util.{ArrayList => JArrayList}
-
-import scala.collection.JavaConverters._
import net.razorvine.pickle._
-import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.ml.linalg._
import org.apache.spark.mllib.api.python.SerDeBase
-import org.apache.spark.rdd.RDD
/**
* SerDe utility functions for pyspark.ml.
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index ce35593..bb01f9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -21,7 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7fddfd9..536c58f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -37,7 +37,6 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 597428d..e934e5e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -22,10 +22,10 @@ import scala.util.Random
import breeze.linalg.{DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.ml.classification.NaiveBayesSuite._
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -106,6 +106,11 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}
+ test("model types") {
+ assert(Multinomial === "multinomial")
+ assert(Bernoulli === "bernoulli")
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -228,6 +233,66 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
+ test("detect negative values") {
+ val dense = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(-1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+ intercept[SparkException] {
+ new NaiveBayes().fit(dense)
+ }
+ val sparse = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))))
+ intercept[SparkException] {
+ new NaiveBayes().fit(sparse)
+ }
+ val nan = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))))
+ intercept[SparkException] {
+ new NaiveBayes().fit(nan)
+ }
+ }
+
+ test("detect non zero or one values in Bernoulli") {
+ val badTrain = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+
+ intercept[SparkException] {
+ new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(badTrain)
+ }
+
+ val okTrain = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0))))
+
+ val model = new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(okTrain)
+
+ val badPredict = spark.createDataFrame(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+
+ intercept[SparkException] {
+ model.transform(badPredict).collect()
+ }
+ }
+
test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
assert(model.pi === model2.pi)
http://git-wip-us.apache.org/repos/asf/spark/blob/c17f9718/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 505e7bf..ea60fab 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -16,7 +16,6 @@
#
import operator
-import warnings
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org