You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/09/02 00:32:36 UTC
spark git commit: [SPARK-21729][ML][TEST] Generic test for
ProbabilisticClassifier to ensure consistent output columns
Repository: spark
Updated Branches:
refs/heads/master aba9492d2 -> 900f14f6f
[SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns
## What changes were proposed in this pull request?
Add a test that runs prediction with the model for every combination of output columns (rawPrediction, probability, prediction) turned on/off.
Make sure the output column values match by comparing them against the case with all 3 output columns turned on.
## How was this patch tested?
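Tests updated: the new helper `ProbabilisticClassifierSuite.testPredictMethods` is called from the DecisionTree, GBT, LogisticRegression, MultilayerPerceptron, NaiveBayes and RandomForest classifier suites.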
Author: WeichenXu <we...@databricks.com>
Author: WeichenXu <We...@outlook.com>
Closes #19065 from WeichenXu123/generic_test_for_prob_classifier.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/900f14f6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/900f14f6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/900f14f6
Branch: refs/heads/master
Commit: 900f14f6fad50369aa849922447f60d7cf06cf2f
Parents: aba9492
Author: WeichenXu <we...@databricks.com>
Authored: Fri Sep 1 17:32:33 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Sep 1 17:32:33 2017 -0700
----------------------------------------------------------------------
.../DecisionTreeClassifierSuite.scala | 3 +
.../ml/classification/GBTClassifierSuite.scala | 3 +
.../LogisticRegressionSuite.scala | 6 ++
.../MultilayerPerceptronClassifierSuite.scala | 2 +
.../ml/classification/NaiveBayesSuite.scala | 6 ++
.../ProbabilisticClassifierSuite.scala | 60 ++++++++++++++++++++
.../RandomForestClassifierSuite.scala | 2 +
7 files changed, 82 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 918ab27..98c879e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
"probability prediction mismatch")
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, DecisionTreeClassificationModel](newTree, newData)
}
test("training with 1-category categorical feature") {
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1f79e0d..8000143 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, GBTClassificationModel](gbtModel, validationDataset)
}
test("GBT parameter stepSize should be in interval (0, 1]") {
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 6bf1253..d43c7cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -502,6 +502,9 @@ class LogisticRegressionSuite
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, LogisticRegressionModel](model, smallMultinomialDataset)
}
test("binary logistic regression: Predictor, Classifier methods") {
@@ -556,6 +559,9 @@ class LogisticRegressionSuite
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, LogisticRegressionModel](model, smallBinaryDataset)
}
test("coefficients and intercept methods") {
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index c294e4a..d3141ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite
case Row(p: Vector, e: Vector) =>
assert(p ~== e absTol 1e-3)
}
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
}
test("test model probability") {
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 3a2be23..9730dd6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "multinomial")
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, NaiveBayesModel](model, testDataset)
}
test("Naive Bayes with weighted samples") {
@@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, NaiveBayesModel](model, testDataset)
}
test("detect negative values") {
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 172c64a..4ecd5a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
final class TestProbabilisticClassificationModel(
override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
"thresholds" -> Array(0.4, 0.6)
)
+ /**
+ * Helper for testing that a ProbabilisticClassificationModel computes
+ * the same predictions across all combinations of output columns
+ * (rawPrediction/probability/prediction) turned on/off. Makes sure the
+ * output column values match by comparing vs. the case with all 3 output
+ * columns turned on.
+ */
+ def testPredictMethods[
+ FeaturesType,
+ M <: ProbabilisticClassificationModel[FeaturesType, M]](
+ model: M, testData: Dataset[_]): Unit = {
+
+ val allColModel = model.copy(ParamMap.empty)
+ .setRawPredictionCol("rawPredictionAll")
+ .setProbabilityCol("probabilityAll")
+ .setPredictionCol("predictionAll")
+ val allColResult = allColModel.transform(testData)
+
+ for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
+ for (probabilityCol <- Seq("", "probabilitySingle")) {
+ for (predictionCol <- Seq("", "predictionSingle")) {
+ val newModel = model.copy(ParamMap.empty)
+ .setRawPredictionCol(rawPredictionCol)
+ .setProbabilityCol(probabilityCol)
+ .setPredictionCol(predictionCol)
+
+ val result = newModel.transform(allColResult)
+
+ import org.apache.spark.sql.functions._
+
+ val resultRawPredictionCol =
+ if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
+ val resultProbabilityCol =
+ if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
+ val resultPredictionCol =
+ if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+
+ result.select(
+ resultRawPredictionCol, col("rawPredictionAll"),
+ resultProbabilityCol, col("probabilityAll"),
+ resultPredictionCol, col("predictionAll")
+ ).collect().foreach {
+ case Row(
+ rawPredictionSingle: Vector, rawPredictionAll: Vector,
+ probabilitySingle: Vector, probabilityAll: Vector,
+ predictionSingle: Double, predictionAll: Double
+ ) => {
+ assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3)
+ assert(probabilitySingle ~== probabilityAll relTol 1E-3)
+ assert(predictionSingle ~== predictionAll relTol 1E-3)
+ }
+ }
+ }
+ }
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ca2954d..2cca2e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
"probability prediction mismatch")
assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
}
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, RandomForestClassificationModel](model, df)
}
test("Fitting without numClasses in metadata") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org