You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/03/05 18:49:58 UTC

spark git commit: [SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification

Repository: spark
Updated Branches:
  refs/heads/master 4586eada4 -> 98a5c0a35


[SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification

## What changes were proposed in this pull request?

Adds Structured Streaming tests for all Models/Transformers in spark.ml.classification.

## How was this patch tested?

N/A (this patch only modifies existing test suites).

Author: WeichenXu <we...@databricks.com>

Closes #20121 from WeichenXu123/ml_stream_test_classification.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98a5c0a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98a5c0a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98a5c0a3

Branch: refs/heads/master
Commit: 98a5c0a35f0a24730f5074522939acf57ef95422
Parents: 4586ead
Author: WeichenXu <we...@databricks.com>
Authored: Mon Mar 5 10:50:00 2018 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Mar 5 10:50:00 2018 -0800

----------------------------------------------------------------------
 .../DecisionTreeClassifierSuite.scala           |  29 ++-
 .../ml/classification/GBTClassifierSuite.scala  |  77 ++-----
 .../ml/classification/LinearSVCSuite.scala      |  15 +-
 .../LogisticRegressionSuite.scala               | 229 +++++++------------
 .../MultilayerPerceptronClassifierSuite.scala   |  44 ++--
 .../ml/classification/NaiveBayesSuite.scala     |  47 ++--
 .../ml/classification/OneVsRestSuite.scala      |  21 +-
 .../ProbabilisticClassifierSuite.scala          |  29 +--
 .../RandomForestClassifierSuite.scala           |  16 +-
 9 files changed, 202 insertions(+), 305 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 38b265d..eeb0324 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+  DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeClassifierSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import DecisionTreeClassifierSuite.compareAPIs
   import testImplicits._
@@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite
 
     MLTestingUtils.checkCopyAndUids(dt, newTree)
 
-    val predictions = newTree.transform(newData)
-      .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
-      .collect()
-
-    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
-      assert(pred === rawPred.argmax,
-        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
-      val sum = rawPred.toArray.sum
-      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
-        "probability prediction mismatch")
+    testTransformer[(Vector, Double)](newData, newTree,
+      "prediction", "rawPrediction", "probability") {
+      case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+        assert(pred === rawPred.argmax,
+          s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+        val sum = rawPred.toArray.sum
+        assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+          "probability prediction mismatch")
     }
 
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, DecisionTreeClassificationModel](newTree, newData)
+      Vector, DecisionTreeClassificationModel](this, newTree, newData)
   }
 
   test("training with 1-category categorical feature") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 978f89c..092b4a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -26,13 +26,12 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.LogLoss
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
@@ -40,8 +39,7 @@ import org.apache.spark.util.Utils
 /**
  * Test suite for [[GBTClassifier]].
  */
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
   import GBTClassifierSuite.compareAPIs
@@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // should predict all zeros
     binaryModel.setThresholds(Array(0.0, 1.0))
-    val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => prediction === 0.0
+    }
 
     // should predict all ones
     binaryModel.setThresholds(Array(1.0, 0.0))
-    val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+    testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+      case Row(prediction: Double) => prediction === 1.0
+    }
 
     val gbtBase = new GBTClassifier
     val model = gbtBase.fit(df)
@@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // constant threshold scaling is the same as no thresholds
     binaryModel.setThresholds(Array(1.0, 1.0))
-    val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
-    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
-      scaled.getDouble(0) === base.getDouble(0)
-    })
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") {
+      scaledPredictions: Seq[Row] =>
+        assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+          scaled.getDouble(0) === base.getDouble(0)
+        })
+    }
 
     // force it to use the predict method
     model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
-    val predictionsWithPredict = model.transform(df).select("prediction").collect()
-    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](df, model, "prediction") {
+      case Row(prediction: Double) => prediction === 0.0
+    }
   }
 
   test("GBTClassifier: Predictor, Classifier methods") {
@@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     val blas = BLAS.getInstance()
 
     val validationDataset = validationData.toDF(labelCol, featuresCol)
-    val results = gbtModel.transform(validationDataset)
-    // check that raw prediction is tree predictions dot tree weights
-    results.select(rawPredictionCol, featuresCol).collect().foreach {
-      case Row(raw: Vector, features: Vector) =>
+    testTransformer[(Double, Vector)](validationDataset, gbtModel,
+      "rawPrediction", "features", "probability", "prediction") {
+      case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) =>
         assert(raw.size === 2)
+        // check that raw prediction is tree predictions dot tree weights
         val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
         val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
         assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
-    }
 
-    // Compare rawPrediction with probability
-    results.select(rawPredictionCol, probabilityCol).collect().foreach {
-      case Row(raw: Vector, prob: Vector) =>
-        assert(raw.size === 2)
+        // Compare rawPrediction with probability
         assert(prob.size === 2)
         // Note: we should check other loss types for classification if they are added
         val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
         assert(prob(0) ~== predFromRaw(0) relTol eps)
         assert(prob(1) ~== predFromRaw(1) relTol eps)
         assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
-    }
 
-    // Compare prediction with probability
-    results.select(predictionCol, probabilityCol).collect().foreach {
-      case Row(pred: Double, prob: Vector) =>
+        // Compare prediction with probability
         val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
         assert(pred == predFromProb)
     }
 
-    // force it to use raw2prediction
-    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
-    val resultsUsingRaw2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use probability2prediction
-    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
-    val resultsUsingProb2Predict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use predict
-    gbtModel.setRawPredictionCol("").setProbabilityCol("")
-    val resultsUsingPredict =
-      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
-    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, GBTClassificationModel](gbtModel, validationDataset)
+      Vector, GBTClassificationModel](this, gbtModel, validationDataset)
   }
 
   test("GBT parameter stepSize should be in interval (0, 1]") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index 41a5d22..a93825b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -21,20 +21,18 @@ import scala.util.Random
 
 import breeze.linalg.{DenseVector => BDV}
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LinearSVCSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.optim.aggregator.HingeAggregator
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.udf
 
 
-class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
         threshold: Double,
         expected: Set[(Int, Double)]): Unit = {
       model.setThreshold(threshold)
-      val results = model.transform(df).select("id", "prediction").collect()
-        .map(r => (r.getInt(0), r.getDouble(1)))
-        .toSet
-      assert(results === expected, s"Failed for threshold = $threshold")
+      testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") {
+        rows: Seq[Row] =>
+          val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet
+          assert(results === expected, s"Failed for threshold = $threshold")
+      }
     }
 
     def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a5f81a3..9987cbf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -22,22 +22,20 @@ import scala.language.existentials
 import scala.util.Random
 import scala.util.control.Breaks._
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors}
 import org.apache.spark.ml.optim.aggregator.LogisticAggregator
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit, rand}
 import org.apache.spark.sql.types.LongType
 
-class LogisticRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -332,15 +330,14 @@ class LogisticRegressionSuite
     val binaryModel = blr.fit(smallBinaryDataset)
 
     binaryModel.setThreshold(1.0)
-    val binaryZeroPredictions =
-      binaryModel.transform(smallBinaryDataset).select("prediction").collect()
-    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") {
+      row => assert(row.getDouble(0) === 0.0)
+    }
 
     binaryModel.setThreshold(0.0)
-    val binaryOnePredictions =
-      binaryModel.transform(smallBinaryDataset).select("prediction").collect()
-    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+    testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") {
+      row => assert(row.getDouble(0) === 1.0)
+    }
 
     val mlr = new LogisticRegression().setFamily("multinomial")
     val model = mlr.fit(smallMultinomialDataset)
@@ -348,31 +345,36 @@ class LogisticRegressionSuite
 
     // should predict all zeros
     model.setThresholds(Array(1, 1000, 1000))
-    val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
-    assert(zeroPredictions.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+      row => assert(row.getDouble(0) === 0.0)
+    }
 
     // should predict all ones
     model.setThresholds(Array(1000, 1, 1000))
-    val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
-    assert(onePredictions.forall(_.getDouble(0) === 1.0))
+    testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+      row => assert(row.getDouble(0) === 1.0)
+    }
 
     // should predict all twos
     model.setThresholds(Array(1000, 1000, 1))
-    val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
-    assert(twoPredictions.forall(_.getDouble(0) === 2.0))
+    testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+      row => assert(row.getDouble(0) === 2.0)
+    }
 
     // constant threshold scaling is the same as no thresholds
     model.setThresholds(Array(1000, 1000, 1000))
-    val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
-    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
-      scaled.getDouble(0) === base.getDouble(0)
-    })
+    testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model,
+      "prediction") { scaledPredictions: Seq[Row] =>
+      assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+        scaled.getDouble(0) === base.getDouble(0)
+      })
+    }
 
     // force it to use the predict method
     model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1))
-    val predictionsWithPredict =
-      model.transform(smallMultinomialDataset).select("prediction").collect()
-    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+    testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+      row => assert(row.getDouble(0) === 0.0)
+    }
   }
 
   test("logistic regression doesn't fit intercept when fitIntercept is off") {
@@ -403,21 +405,19 @@ class LogisticRegressionSuite
 
     // Modify model params, and check that the params worked.
     model.setThreshold(1.0)
-    val predAllZero = model.transform(smallBinaryDataset)
-      .select("prediction", "myProbability")
-      .collect()
-      .map { case Row(pred: Double, prob: Vector) => pred }
-    assert(predAllZero.forall(_ === 0),
-      s"With threshold=1.0, expected predictions to be all 0, but only" +
-      s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.")
+    testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(),
+      model, "prediction", "myProbability") { rows =>
+      val predAllZero = rows.map(_.getDouble(0))
+      assert(predAllZero.forall(_ === 0),
+        s"With threshold=1.0, expected predictions to be all 0, but only" +
+        s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.")
+    }
     // Call transform with params, and check that the params worked.
-    val predNotAllZero =
-      model.transform(smallBinaryDataset, model.threshold -> 0.0,
-        model.probabilityCol -> "myProb")
-        .select("prediction", "myProb")
-        .collect()
-        .map { case Row(pred: Double, prob: Vector) => pred }
-    assert(predNotAllZero.exists(_ !== 0.0))
+    testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(),
+      model.copy(ParamMap(model.threshold -> 0.0,
+        model.probabilityCol -> "myProb")), "prediction", "myProb") {
+      rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0))
+    }
 
     // Call fit() with new params, and check as many params as we can.
     lr.setThresholds(Array(0.6, 0.4))
@@ -441,10 +441,10 @@ class LogisticRegressionSuite
     val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size
     assert(model.numFeatures === numFeatures)
 
-    val results = model.transform(smallMultinomialDataset)
-    // check that raw prediction is coefficients dot features + intercept
-    results.select("rawPrediction", "features").collect().foreach {
-      case Row(raw: Vector, features: Vector) =>
+    testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(),
+      model, "rawPrediction", "features", "probability") {
+      case Row(raw: Vector, features: Vector, prob: Vector) =>
+        // check that raw prediction is coefficients dot features + intercept
         assert(raw.size === 3)
         val margins = Array.tabulate(3) { k =>
           var margin = 0.0
@@ -455,12 +455,7 @@ class LogisticRegressionSuite
           margin
         }
         assert(raw ~== Vectors.dense(margins) relTol eps)
-    }
-
-    // Compare rawPrediction with probability
-    results.select("rawPrediction", "probability").collect().foreach {
-      case Row(raw: Vector, prob: Vector) =>
-        assert(raw.size === 3)
+        // Compare rawPrediction with probability
         assert(prob.size === 3)
         val max = raw.toArray.max
         val subtract = if (max > 0) max else 0.0
@@ -472,39 +467,8 @@ class LogisticRegressionSuite
         assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps)
     }
 
-    // Compare prediction with probability
-    results.select("prediction", "probability").collect().foreach {
-      case Row(pred: Double, prob: Vector) =>
-        val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
-        assert(pred == predFromProb)
-    }
-
-    // force it to use raw2prediction
-    model.setRawPredictionCol("rawPrediction").setProbabilityCol("")
-    val resultsUsingRaw2Predict =
-      model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
-    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use probability2prediction
-    model.setRawPredictionCol("").setProbabilityCol("probability")
-    val resultsUsingProb2Predict =
-      model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
-    resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use predict
-    model.setRawPredictionCol("").setProbabilityCol("")
-    val resultsUsingPredict =
-      model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
-    resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, LogisticRegressionModel](model, smallMultinomialDataset)
+      Vector, LogisticRegressionModel](this, model, smallMultinomialDataset)
   }
 
   test("binary logistic regression: Predictor, Classifier methods") {
@@ -517,51 +481,22 @@ class LogisticRegressionSuite
     val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size
     assert(model.numFeatures === numFeatures)
 
-    val results = model.transform(smallBinaryDataset)
-
-    // Compare rawPrediction with probability
-    results.select("rawPrediction", "probability").collect().foreach {
-      case Row(raw: Vector, prob: Vector) =>
+    testTransformer[(Double, Vector)](smallBinaryDataset.toDF(),
+      model, "rawPrediction", "probability", "prediction") {
+      case Row(raw: Vector, prob: Vector, pred: Double) =>
+        // Compare rawPrediction with probability
         assert(raw.size === 2)
         assert(prob.size === 2)
         val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
         assert(prob(1) ~== probFromRaw1 relTol eps)
         assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
-    }
-
-    // Compare prediction with probability
-    results.select("prediction", "probability").collect().foreach {
-      case Row(pred: Double, prob: Vector) =>
+        // Compare prediction with probability
         val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
         assert(pred == predFromProb)
     }
 
-    // force it to use raw2prediction
-    model.setRawPredictionCol("rawPrediction").setProbabilityCol("")
-    val resultsUsingRaw2Predict =
-      model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
-    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use probability2prediction
-    model.setRawPredictionCol("").setProbabilityCol("probability")
-    val resultsUsingProb2Predict =
-      model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
-    resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
-    // force it to use predict
-    model.setRawPredictionCol("").setProbabilityCol("")
-    val resultsUsingPredict =
-      model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
-    resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
-      case (pred1, pred2) => assert(pred1 === pred2)
-    }
-
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, LogisticRegressionModel](model, smallBinaryDataset)
+      Vector, LogisticRegressionModel](this, model, smallBinaryDataset)
   }
 
   test("coefficients and intercept methods") {
@@ -616,19 +551,21 @@ class LogisticRegressionSuite
       LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)),
       LabeledPoint(1.0, Vectors.dense(0.0, -1.0))
     ).toDF()
-    val results = model.transform(overFlowData).select("rawPrediction", "probability").collect()
-
-    // probabilities are correct when margins have to be adjusted
-    val raw1 = results(0).getAs[Vector](0)
-    val prob1 = results(0).getAs[Vector](1)
-    assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
-    assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
-
-    // probabilities are correct when margins don't have to be adjusted
-    val raw2 = results(1).getAs[Vector](0)
-    val prob2 = results(1).getAs[Vector](1)
-    assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
-    assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
+
+    testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(),
+      model, "rawPrediction", "probability") { results: Seq[Row] =>
+        // probabilities are correct when margins have to be adjusted
+        val raw1 = results(0).getAs[Vector](0)
+        val prob1 = results(0).getAs[Vector](1)
+        assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
+        assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
+
+        // probabilities are correct when margins don't have to be adjusted
+        val raw2 = results(1).getAs[Vector](0)
+        val prob2 = results(1).getAs[Vector](1)
+        assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
+        assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
+    }
   }
 
   test("MultiClassSummarizer") {
@@ -2567,10 +2504,13 @@ class LogisticRegressionSuite
     val model1 = lr.fit(smallBinaryDataset)
     val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial")
     val model2 = lr2.fit(smallBinaryDataset)
-    val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect()
-    val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect()
-    predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
-      assert(p1 === p2)
+    val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect()
+      .map(_.getDouble(0))
+    for (model <- Seq(model1, model2)) {
+      testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model,
+        "prediction") { rows: Seq[Row] =>
+        rows.map(_.getDouble(0)).toArray === binaryExpected
+      }
     }
     assert(model2.summary.totalIterations === 1)
 
@@ -2579,10 +2519,13 @@ class LogisticRegressionSuite
     val lr4 = new LogisticRegression()
       .setInitialModel(model3).setMaxIter(5).setFamily("multinomial")
     val model4 = lr4.fit(smallMultinomialDataset)
-    val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect()
-    val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect()
-    predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
-      assert(p1 === p2)
+    val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction")
+      .collect().map(_.getDouble(0))
+    for (model <- Seq(model3, model4)) {
+      testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model,
+        "prediction") { rows: Seq[Row] =>
+        rows.map(_.getDouble(0)).toArray === multinomialExpected
+      }
     }
     assert(model4.summary.totalIterations === 1)
   }
@@ -2638,8 +2581,8 @@ class LogisticRegressionSuite
       LabeledPoint(4.0, Vectors.dense(2.0))).toDF()
     val mlr = new LogisticRegression().setFamily("multinomial")
     val model = mlr.fit(constantData)
-    val results = model.transform(constantData)
-    results.select("rawPrediction", "probability", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](constantData, model,
+      "rawPrediction", "probability", "prediction") {
       case Row(raw: Vector, prob: Vector, pred: Double) =>
         assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity)))
         assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
@@ -2653,8 +2596,8 @@ class LogisticRegressionSuite
       LabeledPoint(0.0, Vectors.dense(1.0)),
       LabeledPoint(0.0, Vectors.dense(2.0))).toDF()
     val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData)
-    val resultsZero = modelZeroLabel.transform(constantZeroData)
-    resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel,
+      "rawPrediction", "probability", "prediction") {
       case Row(raw: Vector, prob: Vector, pred: Double) =>
         assert(prob === Vectors.dense(Array(1.0)))
         assert(pred === 0.0)
@@ -2666,8 +2609,8 @@ class LogisticRegressionSuite
     val constantDataWithMetadata = constantData
       .select(constantData("label").as("label", labelMeta), constantData("features"))
     val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata)
-    val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata)
-    resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata,
+      "rawPrediction", "probability", "prediction") {
       case Row(raw: Vector, prob: Vector, pred: Double) =>
         assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0)))
         assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index d3141ec..daa58a5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -17,22 +17,17 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions._
 
-class MultilayerPerceptronClassifierSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -75,11 +70,9 @@ class MultilayerPerceptronClassifierSuite
       .setMaxIter(100)
       .setSolver("l-bfgs")
     val model = trainer.fit(dataset)
-    val result = model.transform(dataset)
     MLTestingUtils.checkCopyAndUids(trainer, model)
-    val predictionAndLabels = result.select("prediction", "label").collect()
-    predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
-      assert(p == l)
+    testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") {
+      case Row(p: Double, l: Double) => assert(p == l)
     }
   }
 
@@ -99,13 +92,12 @@ class MultilayerPerceptronClassifierSuite
       .setMaxIter(100)
       .setSolver("l-bfgs")
     val model = trainer.fit(strongDataset)
-    val result = model.transform(strongDataset)
-    result.select("probability", "expectedProbability").collect().foreach {
-      case Row(p: Vector, e: Vector) =>
-        assert(p ~== e absTol 1e-3)
+    testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model,
+      "probability", "expectedProbability") {
+      case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3)
     }
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
+      Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset)
   }
 
   test("test model probability") {
@@ -118,11 +110,10 @@ class MultilayerPerceptronClassifierSuite
       .setSolver("l-bfgs")
     val model = trainer.fit(dataset)
     model.setProbabilityCol("probability")
-    val result = model.transform(dataset)
-    val features2prob = udf { features: Vector => model.mlpModel.predict(features) }
-    result.select(features2prob(col("features")), col("probability")).collect().foreach {
-      case Row(p1: Vector, p2: Vector) =>
-        assert(p1 ~== p2 absTol 1e-3)
+    testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") {
+      case Row(features: Vector, prob: Vector) =>
+        val prob2 = model.mlpModel.predict(features)
+        assert(prob ~== prob2 absTol 1e-3)
     }
   }
 
@@ -175,9 +166,6 @@ class MultilayerPerceptronClassifierSuite
     val model = trainer.fit(dataFrame)
     val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
     assert(model.numFeatures === numFeatures)
-    val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map {
-      case Row(p: Double, l: Double) => (p, l)
-    }
     // train multinomial logistic regression
     val lr = new LogisticRegressionWithLBFGS()
       .setIntercept(true)
@@ -189,8 +177,12 @@ class MultilayerPerceptronClassifierSuite
       lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label))
     // MLP's predictions should not differ a lot from LR's.
     val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
-    val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
-    assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100)
+    testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") {
+      rows: Seq[Row] =>
+        val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1)))
+        val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels))
+        assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100)
+    }
   }
 
   test("read/write: MultilayerPerceptronClassifier") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 0d3adf9..49115c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF()
   }
 
-  def validatePrediction(predictionAndLabels: DataFrame): Unit = {
-    val numOfErrorPredictions = predictionAndLabels.collect().count {
+  def validatePrediction(predictionAndLabels: Seq[Row]): Unit = {
+    val numOfErrorPredictions = predictionAndLabels.filter {
       case Row(prediction: Double, label: Double) =>
         prediction != label
-    }
+    }.length
     // At least 80% of the predictions should be on.
-    assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
+    assert(numOfErrorPredictions < predictionAndLabels.length / 5)
   }
 
   def validateModelFit(
@@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
   }
 
   def validateProbabilities(
-      featureAndProbabilities: DataFrame,
+      featureAndProbabilities: Seq[Row],
       model: NaiveBayesModel,
       modelType: String): Unit = {
-    featureAndProbabilities.collect().foreach {
+    featureAndProbabilities.foreach {
       case Row(features: Vector, probability: Vector) =>
         assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
         val expected = modelType match {
@@ -154,15 +153,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val validationDataset =
       generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
 
-    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
-    validatePrediction(predictionAndLabels)
+    testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+      "prediction", "label") { predictionAndLabels: Seq[Row] =>
+      validatePrediction(predictionAndLabels)
+    }
 
-    val featureAndProbabilities = model.transform(validationDataset)
-      .select("features", "probability")
-    validateProbabilities(featureAndProbabilities, model, "multinomial")
+    testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+      "features", "probability") { featureAndProbabilities: Seq[Row] =>
+      validateProbabilities(featureAndProbabilities, model, "multinomial")
+    }
 
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, NaiveBayesModel](model, testDataset)
+      Vector, NaiveBayesModel](this, model, testDataset)
   }
 
   test("Naive Bayes with weighted samples") {
@@ -210,15 +212,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val validationDataset =
       generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF()
 
-    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
-    validatePrediction(predictionAndLabels)
+    testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+      "prediction", "label") { predictionAndLabels: Seq[Row] =>
+      validatePrediction(predictionAndLabels)
+    }
 
-    val featureAndProbabilities = model.transform(validationDataset)
-      .select("features", "probability")
-    validateProbabilities(featureAndProbabilities, model, "bernoulli")
+    testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+      "features", "probability") { featureAndProbabilities: Seq[Row] =>
+      validateProbabilities(featureAndProbabilities, model, "bernoulli")
+    }
 
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, NaiveBayesModel](model, testDataset)
+      Vector, NaiveBayesModel](this, model, testDataset)
   }
 
   test("detect negative values") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 25bad59..11e8836 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,26 +17,24 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.Metadata
 
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -85,10 +83,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
     assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
 
-    val ovaResults = transformedDataset.select("prediction", "label").rdd.map {
-      row => (row.getDouble(0), row.getDouble(1))
-    }
-
     val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
     lr.optimizer.setRegParam(0.1).setNumIterations(100)
 
@@ -97,8 +91,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     // determine the #confusion matrix in each class.
     // bound how much error we allow compared to multinomial logistic regression.
     val expectedMetrics = new MulticlassMetrics(results)
-    val ovaMetrics = new MulticlassMetrics(ovaResults)
-    assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400)
+
+    testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel,
+      "prediction", "label") { rows =>
+      val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) }
+      val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults))
+      assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400)
+    }
   }
 
   test("one-vs-rest: tuning parallelism does not change output") {

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index d649cea..1c8c982 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.MLTest
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.sql.{Dataset, Row}
 
@@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite {
   def testPredictMethods[
       FeaturesType,
       M <: ProbabilisticClassificationModel[FeaturesType, M]](
-    model: M, testData: Dataset[_]): Unit = {
+    mlTest: MLTest, model: M, testData: Dataset[_]): Unit = {
 
     val allColModel = model.copy(ParamMap.empty)
       .setRawPredictionCol("rawPredictionAll")
       .setProbabilityCol("probabilityAll")
       .setPredictionCol("predictionAll")
-    val allColResult = allColModel.transform(testData)
+
+    val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol))
+      .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll")
 
     for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
       for (probabilityCol <- Seq("", "probabilitySingle")) {
@@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite {
             .setProbabilityCol(probabilityCol)
             .setPredictionCol(predictionCol)
 
-          val result = newModel.transform(allColResult)
-
-          import org.apache.spark.sql.functions._
-
-          val resultRawPredictionCol =
-            if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
-          val resultProbabilityCol =
-            if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
-          val resultPredictionCol =
-            if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+          import allColResult.sparkSession.implicits._
 
-          result.select(
-            resultRawPredictionCol, col("rawPredictionAll"),
-            resultProbabilityCol, col("probabilityAll"),
-            resultPredictionCol, col("predictionAll")
-          ).collect().foreach {
+          mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel,
+            if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol,
+            "rawPredictionAll",
+            if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll",
+            if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll"
+          ) {
             case Row(
               rawPredictionSingle: Vector, rawPredictionAll: Vector,
               probabilitySingle: Vector, probabilityAll: Vector,

http://git-wip-us.apache.org/repos/asf/spark/blob/98a5c0a3/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 2cca2e6..02a9d5c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -23,11 +23,10 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
@@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 /**
  * Test suite for [[RandomForestClassifier]].
  */
-class RandomForestClassifierSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   import RandomForestClassifierSuite.compareAPIs
   import testImplicits._
@@ -143,11 +141,8 @@ class RandomForestClassifierSuite
 
     MLTestingUtils.checkCopyAndUids(rf, model)
 
-    val predictions = model.transform(df)
-      .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
-      .collect()
-
-    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+    testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction",
+      "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
       assert(pred === rawPred.argmax,
         s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
       val sum = rawPred.toArray.sum
@@ -155,8 +150,9 @@ class RandomForestClassifierSuite
         "probability prediction mismatch")
       assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
     }
+
     ProbabilisticClassifierSuite.testPredictMethods[
-      Vector, RandomForestClassificationModel](model, df)
+      Vector, RandomForestClassificationModel](this, model, df)
   }
 
   test("Fitting without numClasses in metadata") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org