You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/09/02 00:32:36 UTC

spark git commit: [SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns

Repository: spark
Updated Branches:
  refs/heads/master aba9492d2 -> 900f14f6f


[SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns

## What changes were proposed in this pull request?

Add a test for prediction using the model with all combinations of output columns turned on/off.
Make sure the output column values match by comparing against the case with all 3 output columns turned on.

## How was this patch tested?

Existing classifier test suites were updated to call the new generic test helper.

Author: WeichenXu <we...@databricks.com>
Author: WeichenXu <We...@outlook.com>

Closes #19065 from WeichenXu123/generic_test_for_prob_classifier.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/900f14f6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/900f14f6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/900f14f6

Branch: refs/heads/master
Commit: 900f14f6fad50369aa849922447f60d7cf06cf2f
Parents: aba9492
Author: WeichenXu <we...@databricks.com>
Authored: Fri Sep 1 17:32:33 2017 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Fri Sep 1 17:32:33 2017 -0700

----------------------------------------------------------------------
 .../DecisionTreeClassifierSuite.scala           |  3 +
 .../ml/classification/GBTClassifierSuite.scala  |  3 +
 .../LogisticRegressionSuite.scala               |  6 ++
 .../MultilayerPerceptronClassifierSuite.scala   |  2 +
 .../ml/classification/NaiveBayesSuite.scala     |  6 ++
 .../ProbabilisticClassifierSuite.scala          | 60 ++++++++++++++++++++
 .../RandomForestClassifierSuite.scala           |  2 +
 7 files changed, 82 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 918ab27..98c879e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
       assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
         "probability prediction mismatch")
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, DecisionTreeClassificationModel](newTree, newData)
   }
 
   test("training with 1-category categorical feature") {

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1f79e0d..8000143 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, GBTClassificationModel](gbtModel, validationDataset)
   }
 
   test("GBT parameter stepSize should be in interval (0, 1]") {

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 6bf1253..d43c7cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -502,6 +502,9 @@ class LogisticRegressionSuite
     resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, LogisticRegressionModel](model, smallMultinomialDataset)
   }
 
   test("binary logistic regression: Predictor, Classifier methods") {
@@ -556,6 +559,9 @@ class LogisticRegressionSuite
     resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, LogisticRegressionModel](model, smallBinaryDataset)
   }
 
   test("coefficients and intercept methods") {

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index c294e4a..d3141ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite
       case Row(p: Vector, e: Vector) =>
         assert(p ~== e absTol 1e-3)
     }
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
   }
 
   test("test model probability") {

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 3a2be23..9730dd6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val featureAndProbabilities = model.transform(validationDataset)
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "multinomial")
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, NaiveBayesModel](model, testDataset)
   }
 
   test("Naive Bayes with weighted samples") {
@@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val featureAndProbabilities = model.transform(validationDataset)
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "bernoulli")
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, NaiveBayesModel](model, testDataset)
   }
 
   test("detect negative values") {

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 172c64a..4ecd5a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
 
 final class TestProbabilisticClassificationModel(
     override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
     "thresholds" -> Array(0.4, 0.6)
   )
 
+  /**
+   * Checks that `model` yields the same rawPrediction/probability/prediction values for every
+   * on/off combination of those output columns; an empty column name is used for "off".
+   * Each combination is run on the output of a copy with all 3 columns on and compared to it.
+   * @param model    model under test; only copies made via `copy` get new column names
+   * @param testData dataset transformed by the all-columns-on copy of `model`
+   */
+  def testPredictMethods[
+      FeaturesType,
+      M <: ProbabilisticClassificationModel[FeaturesType, M]](
+    model: M, testData: Dataset[_]): Unit = {
+
+    val allColModel = model.copy(ParamMap.empty)
+      .setRawPredictionCol("rawPredictionAll")
+      .setProbabilityCol("probabilityAll")
+      .setPredictionCol("predictionAll")
+    val allColResult = allColModel.transform(testData)
+
+    for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
+      for (probabilityCol <- Seq("", "probabilitySingle")) {
+        for (predictionCol <- Seq("", "predictionSingle")) {
+          val newModel = model.copy(ParamMap.empty)
+            .setRawPredictionCol(rawPredictionCol)
+            .setProbabilityCol(probabilityCol)
+            .setPredictionCol(predictionCol)
+
+          val result = newModel.transform(allColResult)
+
+          import org.apache.spark.sql.functions._
+
+          val resultRawPredictionCol =
+            if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
+          val resultProbabilityCol =
+            if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
+          val resultPredictionCol =
+            if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+
+          result.select(
+            resultRawPredictionCol, col("rawPredictionAll"),
+            resultProbabilityCol, col("probabilityAll"),
+            resultPredictionCol, col("predictionAll")
+          ).collect().foreach {
+            case Row(
+              rawPredictionSingle: Vector, rawPredictionAll: Vector,
+              probabilitySingle: Vector, probabilityAll: Vector,
+              predictionSingle: Double, predictionAll: Double
+            ) => {
+              assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3)
+              assert(probabilitySingle ~== probabilityAll relTol 1E-3)
+              assert(predictionSingle ~== predictionAll relTol 1E-3)
+            }
+          }
+        }
+      }
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ca2954d..2cca2e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
         "probability prediction mismatch")
       assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
     }
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, RandomForestClassificationModel](model, df)
   }
 
   test("Fitting without numClasses in metadata") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org