You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/03/15 01:36:01 UTC

[2/2] spark git commit: [SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z

[SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z

# What changes were proposed in this pull request?

Adds structured streaming tests using testTransformer for these suites:

- NGramSuite
- NormalizerSuite
- OneHotEncoderEstimatorSuite
- OneHotEncoderSuite
- PCASuite
- PolynomialExpansionSuite
- QuantileDiscretizerSuite
- RFormulaSuite
- SQLTransformerSuite
- StandardScalerSuite
- StopWordsRemoverSuite
- StringIndexerSuite
- TokenizerSuite
- RegexTokenizerSuite
- VectorAssemblerSuite
- VectorIndexerSuite
- VectorSizeHintSuite
- VectorSlicerSuite
- Word2VecSuite

# How was this patch tested?

They are unit tests.

Author: “attilapiros” <pi...@gmail.com>

Closes #20686 from attilapiros/SPARK-22915.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/279b3db8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/279b3db8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/279b3db8

Branch: refs/heads/master
Commit: 279b3db8970809104c30941254e57e3d62da5041
Parents: 1098933
Author: “attilapiros” <pi...@gmail.com>
Authored: Wed Mar 14 18:36:31 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Mar 14 18:36:31 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/NGramSuite.scala    |  23 +--
 .../spark/ml/feature/NormalizerSuite.scala      |  57 ++----
 .../feature/OneHotEncoderEstimatorSuite.scala   | 193 +++++++++---------
 .../spark/ml/feature/OneHotEncoderSuite.scala   | 124 ++++++-----
 .../org/apache/spark/ml/feature/PCASuite.scala  |  14 +-
 .../ml/feature/PolynomialExpansionSuite.scala   |  62 +++---
 .../ml/feature/QuantileDiscretizerSuite.scala   | 198 ++++++++++--------
 .../apache/spark/ml/feature/RFormulaSuite.scala | 158 +++++++-------
 .../spark/ml/feature/SQLTransformerSuite.scala  |  35 ++--
 .../spark/ml/feature/StandardScalerSuite.scala  |  33 +--
 .../ml/feature/StopWordsRemoverSuite.scala      |  37 ++--
 .../spark/ml/feature/StringIndexerSuite.scala   | 204 ++++++++++---------
 .../spark/ml/feature/TokenizerSuite.scala       |  30 +--
 .../spark/ml/feature/VectorIndexerSuite.scala   | 183 +++++++++--------
 .../spark/ml/feature/VectorSizeHintSuite.scala  |  88 +++++---
 .../spark/ml/feature/VectorSlicerSuite.scala    |  27 +--
 .../apache/spark/ml/feature/Word2VecSuite.scala |  28 +--
 .../scala/org/apache/spark/ml/util/MLTest.scala |  33 ++-
 18 files changed, 809 insertions(+), 718 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index d4975c0..e5956ee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.ml.feature
 
 import scala.beans.BeanInfo
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
+
 
 @BeanInfo
 case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
 
-class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NGramSuite extends MLTest with DefaultReadWriteTest {
 
-  import org.apache.spark.ml.feature.NGramSuite._
   import testImplicits._
 
   test("default behavior yields bigram features") {
@@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
       .setN(3)
     testDefaultReadWrite(t)
   }
-}
-
-object NGramSuite extends SparkFunSuite {
 
-  def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
-    t.transform(dataset)
-      .select("nGrams", "wantedNGrams")
-      .collect()
-      .foreach { case Row(actualNGrams, wantedNGrams) =>
+  def testNGram(t: NGram, dataFrame: DataFrame): Unit = {
+    testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") {
+      case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) =>
         assert(actualNGrams === wantedNGrams)
-      }
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index c75027f..eff57f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,21 +17,17 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
 
-class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NormalizerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
   @transient var data: Array[Vector] = _
-  @transient var dataFrame: DataFrame = _
-  @transient var normalizer: Normalizer = _
   @transient var l1Normalized: Array[Vector] = _
   @transient var l2Normalized: Array[Vector] = _
 
@@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vectors.dense(0.897906166, 0.113419726, 0.42532397),
       Vectors.sparse(3, Seq())
     )
-
-    dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF()
-    normalizer = new Normalizer()
-      .setInputCol("features")
-      .setOutputCol("normalized_features")
-  }
-
-  def collectResult(result: DataFrame): Array[Vector] = {
-    result.select("normalized_features").collect().map {
-      case Row(features: Vector) => features
-    }
   }
 
-  def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
-    assert((lhs, rhs).zipped.forall {
+  def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
+    assert((lhs, rhs) match {
       case (v1: DenseVector, v2: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after normalization.")
   }
 
-  def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
-    assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
-      vector1 ~== vector2 absTol 1E-5
-    }, "The vector value is not correct after normalization.")
+  def assertValues(lhs: Vector, rhs: Vector): Unit = {
+    assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.")
   }
 
   test("Normalization with default parameter") {
-    val result = collectResult(normalizer.transform(dataFrame))
-
-    assertTypeOfVector(data, result)
+    val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized")
+    val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected")
 
-    assertValues(result, l2Normalized)
+    testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
+      case Row(features: Vector, normalized: Vector, expected: Vector) =>
+        assertTypeOfVector(normalized, features)
+        assertValues(normalized, expected)
+    }
   }
 
   test("Normalization with setter") {
-    normalizer.setP(1)
+    val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected")
+    val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1)
 
-    val result = collectResult(normalizer.transform(dataFrame))
-
-    assertTypeOfVector(data, result)
-
-    assertValues(result, l1Normalized)
+    testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
+      case Row(features: Vector, normalized: Vector, expected: Vector) =>
+        assertTypeOfVector(normalized, features)
+        assertValues(normalized, expected)
+    }
   }
 
   test("read/write") {
@@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     testDefaultReadWrite(t)
   }
 }
-
-private object NormalizerSuite {
-  case class FeatureData(features: Vector)
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
index 1d3f845..d549e13 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
 
-class OneHotEncoderEstimatorSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite
     assert(encoder.getDropLast === true)
     encoder.setDropLast(false)
     assert(encoder.getDropLast === false)
-
     val model = encoder.fit(df)
-    val encoded = model.transform(df)
-    encoded.select("output", "expected").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector)](df, model, "output", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
     }
   }
 
@@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite
       .setOutputCols(Array("output"))
 
     val model = encoder.fit(df)
-    val encoded = model.transform(df)
-    encoded.select("output", "expected").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector)](df, model, "output", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
     }
   }
 
@@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite
       .setInputCols(Array("size"))
       .setOutputCols(Array("encoded"))
     val model = encoder.fit(df)
-    val output = model.transform(df)
-    val group = AttributeGroup.fromStructField(output.schema("encoded"))
-    assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+    testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+        val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+        assert(group.size === 2)
+        assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+        assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+    }
   }
 
   test("input column without ML attribute") {
@@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite
       .setInputCols(Array("index"))
       .setOutputCols(Array("encoded"))
     val model = encoder.fit(df)
-    val output = model.transform(df)
-    val group = AttributeGroup.fromStructField(output.schema("encoded"))
-    assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+    testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+      val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+      assert(group.size === 2)
+      assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+      assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+    }
   }
 
   test("read/write") {
@@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
 
     val df = spark.createDataFrame(sc.parallelize(data), schema)
 
-    val dfWithTypes = df
-      .withColumn("shortInput", df("input").cast(ShortType))
-      .withColumn("longInput", df("input").cast(LongType))
-      .withColumn("intInput", df("input").cast(IntegerType))
-      .withColumn("floatInput", df("input").cast(FloatType))
-      .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
-
-    val cols = Array("input", "shortInput", "longInput", "intInput",
-      "floatInput", "decimalInput")
-    for (col <- cols) {
-      val encoder = new OneHotEncoderEstimator()
-        .setInputCols(Array(col))
+    class NumericTypeWithEncoder[A](val numericType: NumericType)
+      (implicit val encoder: Encoder[(A, Vector)])
+
+    val types = Seq(
+      new NumericTypeWithEncoder[Short](ShortType),
+      new NumericTypeWithEncoder[Long](LongType),
+      new NumericTypeWithEncoder[Int](IntegerType),
+      new NumericTypeWithEncoder[Float](FloatType),
+      new NumericTypeWithEncoder[Byte](ByteType),
+      new NumericTypeWithEncoder[Double](DoubleType),
+      new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+    for (t <- types) {
+      val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
+      val estimator = new OneHotEncoderEstimator()
+        .setInputCols(Array("input"))
         .setOutputCols(Array("output"))
         .setDropLast(false)
 
-      val model = encoder.fit(dfWithTypes)
-      val encoded = model.transform(dfWithTypes)
-
-      encoded.select("output", "expected").rdd.map { r =>
-        (r.getAs[Vector](0), r.getAs[Vector](1))
-      }.collect().foreach { case (vec1, vec2) =>
-        assert(vec1 === vec2)
-      }
+      val model = estimator.fit(dfWithTypes)
+      testTransformer(dfWithTypes, model, "output", "expected") {
+        case Row(output: Vector, expected: Vector) =>
+          assert(output === expected)
+      }(t.encoder)
     }
   }
 
@@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite
     assert(encoder.getDropLast === false)
 
     val model = encoder.fit(df)
-    val encoded = model.transform(df)
-    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
-    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
-      assert(vec1 === vec2)
-      assert(vec3 === vec4)
+    testTransformer[(Double, Vector, Double, Vector)](
+      df,
+      model,
+      "output1",
+      "output2",
+      "expected1",
+      "expected2") {
+      case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) =>
+        assert(output1 === expected1)
+        assert(output2 === expected2)
     }
   }
 
@@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite
       .setOutputCols(Array("output1", "output2"))
 
     val model = encoder.fit(df)
-    val encoded = model.transform(df)
-    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
-    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
-      assert(vec1 === vec2)
-      assert(vec3 === vec4)
+    testTransformer[(Double, Vector, Double, Vector)](
+      df,
+      model,
+      "output1",
+      "output2",
+      "expected1",
+      "expected2") {
+      case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) =>
+        assert(output1 === expected1)
+        assert(output2 === expected2)
     }
   }
 
@@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite
       .setOutputCols(Array("encoded"))
 
     val model = encoder.fit(trainingDF)
-    val err = intercept[SparkException] {
-      model.transform(testDF).show
-    }
-    err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+    testTransformerByInterceptingException[(Int, Int)](
+      testDF,
+      model,
+      expectedMessagePart = "Unseen value: 3.0. To handle unseen values",
+      firstResultCol = "encoded")
+
   }
 
   test("Can't transform on negative input") {
@@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite
       .setOutputCols(Array("encoded"))
 
     val model = encoder.fit(trainingDF)
-    val err = intercept[SparkException] {
-      model.transform(testDF).collect()
-    }
-    err.getMessage.contains("Negative value: -1.0. Input can't be negative")
+    testTransformerByInterceptingException[(Int, Int)](
+      testDF,
+      model,
+      expectedMessagePart = "Negative value: -1.0. Input can't be negative",
+      firstResultCol = "encoded")
   }
 
   test("Keep on invalid values: dropLast = false") {
@@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite
       .setDropLast(false)
 
     val model = encoder.fit(trainingDF)
-    val encoded = model.transform(testDF)
-    encoded.select("output", "expected").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector)](testDF, model, "output", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
     }
   }
 
@@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite
       .setDropLast(true)
 
     val model = encoder.fit(trainingDF)
-    val encoded = model.transform(testDF)
-    encoded.select("output", "expected").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector)](testDF, model, "output", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
     }
   }
 
@@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite
     val model = encoder.fit(df)
 
     model.setDropLast(false)
-    val encoded1 = model.transform(df)
-    encoded1.select("output", "expected1").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") {
+      case Row(output: Vector, expected1: Vector) =>
+        assert(output === expected1)
     }
 
     model.setDropLast(true)
-    val encoded2 = model.transform(df)
-    encoded2.select("output", "expected2").rdd.map { r =>
-      (r.getAs[Vector](0), r.getAs[Vector](1))
-    }.collect().foreach { case (vec1, vec2) =>
-      assert(vec1 === vec2)
+    testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") {
+      case Row(output: Vector, expected2: Vector) =>
+        assert(output === expected2)
     }
   }
 
@@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite
     val model = encoder.fit(trainingDF)
     model.setHandleInvalid("error")
 
-    val err = intercept[SparkException] {
-      model.transform(testDF).collect()
-    }
-    err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+    testTransformerByInterceptingException[(Double, Vector)](
+      testDF,
+      model,
+      expectedMessagePart = "Unseen value: 3.0. To handle unseen values",
+      firstResultCol = "output")
 
     model.setHandleInvalid("keep")
-    model.transform(testDF).collect()
+    testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => }
   }
 
   test("Transforming on mismatched attributes") {
@@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite
     val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
     val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
       .select(col("size").as("size", testAttr.toMetadata()))
-    val err = intercept[Exception] {
-      model.transform(testDF).collect()
-    }
-    err.getMessage.contains("OneHotEncoderModel expected 2 categorical values")
+    testTransformerByInterceptingException[(Double)](
+      testDF,
+      model,
+      expectedMessagePart = "OneHotEncoderModel expected 2 categorical values",
+      firstResultCol = "encoded")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index c44c681..41b32b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
 
 class OneHotEncoderSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+  extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -54,16 +54,19 @@ class OneHotEncoderSuite
     assert(encoder.getDropLast === true)
     encoder.setDropLast(false)
     assert(encoder.getDropLast === false)
-    val encoded = encoder.transform(transformed)
-
-    val output = encoded.select("id", "labelVec").rdd.map { r =>
-      val vec = r.getAs[Vector](1)
-      (r.getInt(0), vec(0), vec(1), vec(2))
-    }.collect().toSet
-    // a -> 0, b -> 2, c -> 1
-    val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
-      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
-    assert(output === expected)
+    val expected = Seq(
+      (0, Vectors.sparse(3, Seq((0, 1.0)))),
+      (1, Vectors.sparse(3, Seq((2, 1.0)))),
+      (2, Vectors.sparse(3, Seq((1, 1.0)))),
+      (3, Vectors.sparse(3, Seq((0, 1.0)))),
+      (4, Vectors.sparse(3, Seq((0, 1.0)))),
+      (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected")
+
+    val withExpected = transformed.join(expected, "id")
+    testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
+    }
   }
 
   test("OneHotEncoder dropLast = true") {
@@ -71,16 +74,19 @@ class OneHotEncoderSuite
     val encoder = new OneHotEncoder()
       .setInputCol("labelIndex")
       .setOutputCol("labelVec")
-    val encoded = encoder.transform(transformed)
-
-    val output = encoded.select("id", "labelVec").rdd.map { r =>
-      val vec = r.getAs[Vector](1)
-      (r.getInt(0), vec(0), vec(1))
-    }.collect().toSet
-    // a -> 0, b -> 2, c -> 1
-    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
-      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
-    assert(output === expected)
+    val expected = Seq(
+      (0, Vectors.sparse(2, Seq((0, 1.0)))),
+      (1, Vectors.sparse(2, Seq())),
+      (2, Vectors.sparse(2, Seq((1, 1.0)))),
+      (3, Vectors.sparse(2, Seq((0, 1.0)))),
+      (4, Vectors.sparse(2, Seq((0, 1.0)))),
+      (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected")
+
+    val withExpected = transformed.join(expected, "id")
+    testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
+    }
   }
 
   test("input column with ML attribute") {
@@ -90,20 +96,22 @@ class OneHotEncoderSuite
     val encoder = new OneHotEncoder()
       .setInputCol("size")
       .setOutputCol("encoded")
-    val output = encoder.transform(df)
-    val group = AttributeGroup.fromStructField(output.schema("encoded"))
-    assert(group.size === 2)
-    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
-    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+    testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows =>
+      val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+      assert(group.size === 2)
+      assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+      assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+    }
   }
 
+
   test("input column without ML attribute") {
     val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
     val encoder = new OneHotEncoder()
       .setInputCol("index")
       .setOutputCol("encoded")
-    val output = encoder.transform(df)
-    val group = AttributeGroup.fromStructField(output.schema("encoded"))
+    val rows = encoder.transform(df).select("encoded").collect()
+    val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
     assert(group.size === 2)
     assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
     assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
@@ -119,29 +127,41 @@ class OneHotEncoderSuite
 
   test("OneHotEncoder with varying types") {
     val df = stringIndexed()
-    val dfWithTypes = df
-      .withColumn("shortLabel", df("labelIndex").cast(ShortType))
-      .withColumn("longLabel", df("labelIndex").cast(LongType))
-      .withColumn("intLabel", df("labelIndex").cast(IntegerType))
-      .withColumn("floatLabel", df("labelIndex").cast(FloatType))
-      .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
-    val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
-      "floatLabel", "decimalLabel")
-    for (col <- cols) {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val expected = Seq(
+      (0, Vectors.sparse(3, Seq((0, 1.0)))),
+      (1, Vectors.sparse(3, Seq((2, 1.0)))),
+      (2, Vectors.sparse(3, Seq((1, 1.0)))),
+      (3, Vectors.sparse(3, Seq((0, 1.0)))),
+      (4, Vectors.sparse(3, Seq((0, 1.0)))),
+      (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected")
+
+    val withExpected = df.join(expected, "id")
+
+    class NumericTypeWithEncoder[A](val numericType: NumericType)
+       (implicit val encoder: Encoder[(A, Vector)])
+
+    val types = Seq(
+      new NumericTypeWithEncoder[Short](ShortType),
+      new NumericTypeWithEncoder[Long](LongType),
+      new NumericTypeWithEncoder[Int](IntegerType),
+      new NumericTypeWithEncoder[Float](FloatType),
+      new NumericTypeWithEncoder[Byte](ByteType),
+      new NumericTypeWithEncoder[Double](DoubleType),
+      new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+    for (t <- types) {
+      val dfWithTypes = withExpected.select(col("labelIndex")
+        .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected"))
       val encoder = new OneHotEncoder()
-        .setInputCol(col)
+        .setInputCol("labelIndex")
         .setOutputCol("labelVec")
         .setDropLast(false)
-      val encoded = encoder.transform(dfWithTypes)
-
-      val output = encoded.select("id", "labelVec").rdd.map { r =>
-        val vec = r.getAs[Vector](1)
-        (r.getInt(0), vec(0), vec(1), vec(2))
-      }.collect().toSet
-      // a -> 0, b -> 2, c -> 1
-      val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
-        (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
-      assert(output === expected)
+
+      testTransformer(dfWithTypes, encoder, "labelVec", "expected") {
+        case Row(output: Vector, expected: Vector) =>
+          assert(output === expected)
+      }(t.encoder)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 3067a52..531b1d7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -17,17 +17,15 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.distributed.RowMatrix
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 
-class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class PCASuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     val pcaModel = pca.fit(df)
 
     MLTestingUtils.checkCopyAndUids(pca, pcaModel)
-
-    pcaModel.transform(df).select("pca_features", "expected").collect().foreach {
-      case Row(x: Vector, y: Vector) =>
-        assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
+    testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") {
+      case Row(result: Vector, expected: Vector) =>
+        assert(result ~== expected absTol 1e-5,
+          "Transformed vector is different with expected vector.")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index e4b0ddf..0be7aa6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,18 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.scalatest.exceptions.TestFailedException
-
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 
-class PolynomialExpansionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -60,6 +55,18 @@ class PolynomialExpansionSuite
       -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0),
     Vectors.sparse(19, Array.empty, Array.empty))
 
+  def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
+    assert((lhs, rhs) match {
+      case (_: DenseVector, _: DenseVector) => true
+      case (_: SparseVector, _: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after polynomial expansion.")
+  }
+
+  def assertValues(lhs: Vector, rhs: Vector): Unit = {
+    assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.")
+  }
+
   test("Polynomial expansion with default parameter") {
     val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected")
 
@@ -67,13 +74,10 @@ class PolynomialExpansionSuite
       .setInputCol("features")
       .setOutputCol("polyFeatures")
 
-    polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
-      case Row(expanded: DenseVector, expected: DenseVector) =>
-        assert(expanded ~== expected absTol 1e-1)
-      case Row(expanded: SparseVector, expected: SparseVector) =>
-        assert(expanded ~== expected absTol 1e-1)
-      case _ =>
-        throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+    testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
+      case Row(expanded: Vector, expected: Vector) =>
+        assertTypeOfVector(expanded, expected)
+        assertValues(expanded, expected)
     }
   }
 
@@ -85,13 +89,10 @@ class PolynomialExpansionSuite
       .setOutputCol("polyFeatures")
       .setDegree(3)
 
-    polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
-      case Row(expanded: DenseVector, expected: DenseVector) =>
-        assert(expanded ~== expected absTol 1e-1)
-      case Row(expanded: SparseVector, expected: SparseVector) =>
-        assert(expanded ~== expected absTol 1e-1)
-      case _ =>
-        throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+    testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
+      case Row(expanded: Vector, expected: Vector) =>
+        assertTypeOfVector(expanded, expected)
+        assertValues(expanded, expected)
     }
   }
 
@@ -103,11 +104,9 @@ class PolynomialExpansionSuite
       .setOutputCol("polyFeatures")
       .setDegree(1)
 
-    polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
+    testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
       case Row(expanded: Vector, expected: Vector) =>
-        assert(expanded ~== expected absTol 1e-1)
-      case _ =>
-        throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+        assertValues(expanded, expected)
     }
   }
 
@@ -133,12 +132,13 @@ class PolynomialExpansionSuite
       .setOutputCol("polyFeatures")
 
     for (i <- Seq(10, 11)) {
-      val transformed = t.setDegree(i)
-        .transform(df)
-        .select(s"expectedPoly${i}size", "polyFeatures")
-        .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
-
-      assert(transformed.collect.forall(identity))
+      testTransformer[(Vector, Int, Int)](
+        df,
+        t.setDegree(i),
+        s"expectedPoly${i}size",
+        "polyFeatures") { case Row(expectedSize: Int, polyFeatures: Vector) =>
+          assert(polyFeatures.size === expectedSize)
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6c36379..b009038 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,15 +17,11 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql._
-import org.apache.spark.sql.functions.udf
 
-class QuantileDiscretizerSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -40,19 +36,19 @@ class QuantileDiscretizerSuite
       .setInputCol("input")
       .setOutputCol("result")
       .setNumBuckets(numBuckets)
-    val result = discretizer.fit(df).transform(df)
-
-    val observedNumBuckets = result.select("result").distinct.count
-    assert(observedNumBuckets === numBuckets,
-      "Observed number of buckets does not equal expected number of buckets.")
+    val model = discretizer.fit(df)
 
-    val relativeError = discretizer.getRelativeError
-    val isGoodBucket = udf {
-      (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
+    testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
+      val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+      val observedNumBuckets = result.select("result").distinct.count
+      assert(observedNumBuckets === numBuckets,
+        "Observed number of buckets does not equal expected number of buckets.")
+      val relativeError = discretizer.getRelativeError
+      val numGoodBuckets = result.groupBy("result").count
+        .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count
+      assert(numGoodBuckets === numBuckets,
+        "Bucket sizes are not within expected relative error tolerance.")
     }
-    val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
-    assert(numGoodBuckets === numBuckets,
-      "Bucket sizes are not within expected relative error tolerance.")
   }
 
   test("Test on data with high proportion of duplicated values") {
@@ -67,11 +63,14 @@ class QuantileDiscretizerSuite
       .setInputCol("input")
       .setOutputCol("result")
       .setNumBuckets(numBuckets)
-    val result = discretizer.fit(df).transform(df)
-    val observedNumBuckets = result.select("result").distinct.count
-    assert(observedNumBuckets == expectedNumBuckets,
-      s"Observed number of buckets are not correct." +
-        s" Expected $expectedNumBuckets but found $observedNumBuckets")
+    val model = discretizer.fit(df)
+    testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
+      val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+      val observedNumBuckets = result.select("result").distinct.count
+      assert(observedNumBuckets == expectedNumBuckets,
+        s"Observed number of buckets are not correct." +
+          s" Expected $expectedNumBuckets but found $observedNumBuckets")
+    }
   }
 
   test("Test transform on data with NaN value") {
@@ -90,17 +89,20 @@ class QuantileDiscretizerSuite
 
     withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
       val dataFrame: DataFrame = validData.toSeq.toDF("input")
-      intercept[SparkException] {
-        discretizer.fit(dataFrame).transform(dataFrame).collect()
-      }
+      val model = discretizer.fit(dataFrame)
+      testTransformerByInterceptingException[(Double)](
+        dataFrame,
+        model,
+        expectedMessagePart = "Bucketizer encountered NaN value.",
+        firstResultCol = "result")
     }
 
     List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
       case(u, v) =>
         discretizer.setHandleInvalid(u)
         val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected")
-        val result = discretizer.fit(dataFrame).transform(dataFrame)
-        result.select("result", "expected").collect().foreach {
+        val model = discretizer.fit(dataFrame)
+        testTransformer[(Double, Double)](dataFrame, model, "result", "expected") {
           case Row(x: Double, y: Double) =>
             assert(x === y,
               s"The feature value is not correct after bucketing.  Expected $y but found $x")
@@ -119,14 +121,17 @@ class QuantileDiscretizerSuite
       .setOutputCol("result")
       .setNumBuckets(5)
 
-    val result = discretizer.fit(trainDF).transform(testDF)
-    val firstBucketSize = result.filter(result("result") === 0.0).count
-    val lastBucketSize = result.filter(result("result") === 4.0).count
+    val model = discretizer.fit(trainDF)
+    testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows =>
+      val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+      val firstBucketSize = result.filter(result("result") === 0.0).count
+      val lastBucketSize = result.filter(result("result") === 4.0).count
 
-    assert(firstBucketSize === 30L,
-      s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
-    assert(lastBucketSize === 31L,
-      s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
+      assert(firstBucketSize === 30L,
+        s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+      assert(lastBucketSize === 31L,
+        s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
+    }
   }
 
   test("read/write") {
@@ -167,21 +172,24 @@ class QuantileDiscretizerSuite
       .setInputCols(Array("input1", "input2"))
       .setOutputCols(Array("result1", "result2"))
       .setNumBuckets(numBuckets)
-    val result = discretizer.fit(df).transform(df)
-
-    val relativeError = discretizer.getRelativeError
-    val isGoodBucket = udf {
-      (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
-    }
-
-    for (i <- 1 to 2) {
-      val observedNumBuckets = result.select("result" + i).distinct.count
-      assert(observedNumBuckets === numBuckets,
-        "Observed number of buckets does not equal expected number of buckets.")
-
-      val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count
-      assert(numGoodBuckets === numBuckets,
-        "Bucket sizes are not within expected relative error tolerance.")
+    val model = discretizer.fit(df)
+    testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
+      val result =
+        rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
+      val relativeError = discretizer.getRelativeError
+      for (i <- 1 to 2) {
+        val observedNumBuckets = result.select("result" + i).distinct.count
+        assert(observedNumBuckets === numBuckets,
+          "Observed number of buckets does not equal expected number of buckets.")
+
+        val numGoodBuckets = result
+          .groupBy("result" + i)
+          .count
+          .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
+          .count
+        assert(numGoodBuckets === numBuckets,
+          "Bucket sizes are not within expected relative error tolerance.")
+      }
     }
   }
 
@@ -198,12 +206,16 @@ class QuantileDiscretizerSuite
       .setInputCols(Array("input1", "input2"))
       .setOutputCols(Array("result1", "result2"))
       .setNumBuckets(numBuckets)
-    val result = discretizer.fit(df).transform(df)
-    for (i <- 1 to 2) {
-      val observedNumBuckets = result.select("result" + i).distinct.count
-      assert(observedNumBuckets == expectedNumBucket,
-        s"Observed number of buckets are not correct." +
-          s" Expected $expectedNumBucket but found ($observedNumBuckets")
+    val model = discretizer.fit(df)
+    testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
+      val result =
+        rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
+      for (i <- 1 to 2) {
+        val observedNumBuckets = result.select("result" + i).distinct.count
+        assert(observedNumBuckets == expectedNumBucket,
+          s"Observed number of buckets are not correct." +
+            s" Expected $expectedNumBucket but found $observedNumBuckets")
+      }
     }
   }
 
@@ -226,9 +238,12 @@ class QuantileDiscretizerSuite
 
     withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
       val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2")
-      intercept[SparkException] {
-        discretizer.fit(dataFrame).transform(dataFrame).collect()
-      }
+      val model = discretizer.fit(dataFrame)
+      testTransformerByInterceptingException[(Double, Double)](
+        dataFrame,
+        model,
+        expectedMessagePart = "Bucketizer encountered NaN value.",
+        firstResultCol = "result1")
     }
 
     List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach {
@@ -237,8 +252,14 @@ class QuantileDiscretizerSuite
         val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map {
           case (((a, b), c), d) => (a, b, c, d)
         }.toSeq.toDF("input1", "input2", "expected1", "expected2")
-        val result = discretizer.fit(dataFrame).transform(dataFrame)
-        result.select("result1", "expected1", "result2", "expected2").collect().foreach {
+        val model = discretizer.fit(dataFrame)
+        testTransformer[(Double, Double, Double, Double)](
+          dataFrame,
+          model,
+          "result1",
+          "expected1",
+          "result2",
+          "expected2") {
           case Row(x: Double, y: Double, z: Double, w: Double) =>
             assert(x === y && w === z)
         }
@@ -270,9 +291,16 @@ class QuantileDiscretizerSuite
       .setOutputCols(Array("result1", "result2", "result3"))
       .setNumBucketsArray(numBucketsArray)
 
-    discretizer.fit(df).transform(df).
-      select("result1", "expected1", "result2", "expected2", "result3", "expected3")
-      .collect().foreach {
+    val model = discretizer.fit(df)
+    testTransformer[(Double, Double, Double, Double, Double, Double)](
+      df,
+      model,
+      "result1",
+      "expected1",
+      "result2",
+      "expected2",
+      "result3",
+      "expected3") {
       case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) =>
         assert(r1 === e1,
           s"The result value is not correct after bucketing. Expected $e1 but found $r1")
@@ -324,20 +352,16 @@ class QuantileDiscretizerSuite
       .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3))
       .fit(df)
 
-    val resultForMultiCols = plForMultiCols.transform(df)
-      .select("result1", "result2", "result3")
-      .collect()
-
-    val resultForSingleCol = plForSingleCol.transform(df)
-      .select("result1", "result2", "result3")
-      .collect()
+    val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect()
 
-    resultForSingleCol.zip(resultForMultiCols).foreach {
-      case (rowForSingle, rowForMultiCols) =>
-        assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
-          rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
-          rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2))
-    }
+    testTransformerByGlobalCheckFunc[(Double, Double, Double)](
+      df,
+      plForMultiCols,
+      "result1",
+      "result2",
+      "result3") { rows =>
+        assert(rows === expected)
+      }
   }
 
   test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " +
@@ -364,18 +388,16 @@ class QuantileDiscretizerSuite
       .setOutputCols(Array("result1", "result2", "result3"))
       .setNumBucketsArray(Array(10, 10, 10))
 
-    val result1 = discretizerSingleNumBuckets.fit(df).transform(df)
-      .select("result1", "result2", "result3")
-      .collect()
-    val result2 = discretizerNumBucketsArray.fit(df).transform(df)
-      .select("result1", "result2", "result3")
-      .collect()
-
-    result1.zip(result2).foreach {
-      case (row1, row2) =>
-        assert(row1.getDouble(0) == row2.getDouble(0) &&
-          row1.getDouble(1) == row2.getDouble(1) &&
-          row1.getDouble(2) == row2.getDouble(2))
+    val model = discretizerSingleNumBuckets.fit(df)
+    val expected = model.transform(df).select("result1", "result2", "result3").collect()
+
+    testTransformerByGlobalCheckFunc[(Double, Double, Double)](
+      df,
+      discretizerNumBucketsArray.fit(df),
+      "result1",
+      "result2",
+      "result3") { rows =>
+      assert(rows === expected)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index bfe38d3..27d570f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
@@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
   def testRFormulaTransform[A: Encoder](
       dataframe: DataFrame,
       formulaModel: RFormulaModel,
-      expected: DataFrame): Unit = {
+      expected: DataFrame,
+      expectedAttributes: AttributeGroup*): Unit = {
+    val resultSchema = formulaModel.transformSchema(dataframe.schema)
+    assert(resultSchema.json === expected.schema.json)
+    assert(resultSchema === expected.schema)
     val (first +: rest) = expected.schema.fieldNames.toSeq
     val expectedRows = expected.collect()
     testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows =>
+      assert(rows.head.schema.toString() == resultSchema.toString())
+      for (expectedAttributeGroup <- expectedAttributes) {
+        val attributeGroup =
+          AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name))
+        assert(attributeGroup === expectedAttributeGroup)
+      }
       assert(rows === expectedRows)
     }
   }
@@ -49,15 +58,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
     val model = formula.fit(original)
     MLTestingUtils.checkCopyAndUids(formula, model)
-    val result = model.transform(original)
-    val resultSchema = model.transformSchema(original.schema)
     val expected = Seq(
       (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
       (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)
     ).toDF("id", "v1", "v2", "features", "label")
-    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
-    assert(result.schema.toString == resultSchema.toString)
-    assert(resultSchema == expected.schema)
     testRFormulaTransform[(Int, Double, Double)](original, model, expected)
   }
 
@@ -73,9 +77,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
     val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y")
     val model = formula.fit(original)
+    val expected = Seq(
+      (0, 1.0, Vectors.dense(0.0)),
+      (2, 2.0, Vectors.dense(2.0))
+    ).toDF("x", "y", "features")
     val resultSchema = model.transformSchema(original.schema)
     assert(resultSchema.length == 3)
-    assert(resultSchema.toString == model.transform(original).schema.toString)
+    testRFormulaTransform[(Int, Double)](original, model, expected)
   }
 
   test("label column already exists but forceIndexLabel was set with true") {
@@ -93,9 +101,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     intercept[IllegalArgumentException] {
       model.transformSchema(original.schema)
     }
-    intercept[IllegalArgumentException] {
-      model.transform(original)
-    }
+    testTransformerByInterceptingException[(Int, Boolean)](
+      original,
+      model,
+      "Label column already exists and is not of type NumericType.",
+      "x")
   }
 
   test("allow missing label column for test datasets") {
@@ -105,21 +115,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val resultSchema = model.transformSchema(original.schema)
     assert(resultSchema.length == 3)
     assert(!resultSchema.exists(_.name == "label"))
-    assert(resultSchema.toString == model.transform(original).schema.toString)
+    val expected = Seq(
+      (0, 1.0, Vectors.dense(0.0)),
+      (2, 2.0, Vectors.dense(2.0))
+    ).toDF("x", "_not_y", "features")
+    testRFormulaTransform[(Int, Double)](original, model, expected)
   }
 
   test("allow empty label") {
     val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b")
     val formula = new RFormula().setFormula("~ a + b")
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val resultSchema = model.transformSchema(original.schema)
     val expected = Seq(
       (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
       (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
       (7, 8.0, 9.0, Vectors.dense(8.0, 9.0))
     ).toDF("id", "a", "b", "features")
-    assert(result.schema.toString == resultSchema.toString)
     testRFormulaTransform[(Int, Double, Double)](original, model, expected)
   }
 
@@ -128,15 +139,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
       .toDF("id", "a", "b")
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val resultSchema = model.transformSchema(original.schema)
     val expected = Seq(
         (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
         (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
         (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
         (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
       ).toDF("id", "a", "b", "features", "label")
-    assert(result.schema.toString == resultSchema.toString)
     testRFormulaTransform[(Int, String, Int)](original, model, expected)
   }
 
@@ -175,9 +183,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     var idx = 0
     for (orderType <- StringIndexer.supportedStringOrderType) {
       val model = formula.setStringIndexerOrderType(orderType).fit(original)
-      val result = model.transform(original)
-      val resultSchema = model.transformSchema(original.schema)
-      assert(result.schema.toString == resultSchema.toString)
       testRFormulaTransform[(Int, String, Int)](original, model, expected(idx))
       idx += 1
     }
@@ -218,9 +223,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     ).toDF("id", "a", "b", "features", "label")
 
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val resultSchema = model.transformSchema(original.schema)
-    assert(result.schema.toString == resultSchema.toString)
     testRFormulaTransform[(Int, String, Int)](original, model, expected)
   }
 
@@ -254,19 +256,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val formula1 = new RFormula().setFormula("id ~ a + b + c - 1")
       .setStringIndexerOrderType(StringIndexer.alphabetDesc)
     val model1 = formula1.fit(original)
-    val result1 = model1.transform(original)
-    val resultSchema1 = model1.transformSchema(original.schema)
-    // Note the column order is different between R and Spark.
-    val expected1 = Seq(
-      (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
-      (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
-      (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
-      (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
-    ).toDF("id", "a", "b", "c", "features", "label")
-    assert(result1.schema.toString == resultSchema1.toString)
-    testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1)
-
-    val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
     val expectedAttrs1 = new AttributeGroup(
       "features",
       Array[Attribute](
@@ -275,14 +264,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         new BinaryAttribute(Some("a_bar"), Some(3)),
         new BinaryAttribute(Some("b_zz"), Some(4)),
         new NumericAttribute(Some("c"), Some(5))))
-    assert(attrs1 === expectedAttrs1)
+    // Note the column order is different between R and Spark.
+    val expected1 = Seq(
+      (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
+      (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
+      (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
+      (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
+    ).toDF("id", "a", "b", "c", "features", "label")
+
+    testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1)
 
     // There is no impact for string terms interaction.
     val formula2 = new RFormula().setFormula("id ~ a:b + c - 1")
       .setStringIndexerOrderType(StringIndexer.alphabetDesc)
     val model2 = formula2.fit(original)
-    val result2 = model2.transform(original)
-    val resultSchema2 = model2.transformSchema(original.schema)
     // Note the column order is different between R and Spark.
     val expected2 = Seq(
       (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0),
@@ -290,10 +285,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
       (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0),
       (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
     ).toDF("id", "a", "b", "c", "features", "label")
-    assert(result2.schema.toString == resultSchema2.toString)
-    testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2)
-
-    val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
     val expectedAttrs2 = new AttributeGroup(
       "features",
       Array[Attribute](
@@ -304,7 +295,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         new NumericAttribute(Some("a_bar:b_zz"), Some(5)),
         new NumericAttribute(Some("a_bar:b_zq"), Some(6)),
         new NumericAttribute(Some("c"), Some(7))))
-    assert(attrs2 === expectedAttrs2)
+
+    testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2)
   }
 
   test("index string label") {
@@ -313,13 +305,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
       Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
         .toDF("id", "a", "b")
     val model = formula.fit(original)
+    val attr = NominalAttribute.defaultAttr
     val expected = Seq(
         ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
         ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
         ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
         ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
     ).toDF("id", "a", "b", "features", "label")
-    // assert(result.schema.toString == resultSchema.toString)
+      .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
     testRFormulaTransform[(String, String, Int)](original, model, expected)
   }
 
@@ -329,13 +322,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
       Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
     ).toDF("id", "a", "b")
     val model = formula.fit(original)
-    val expected = spark.createDataFrame(
-      Seq(
+    val attr = NominalAttribute.defaultAttr
+    val expected = Seq(
         (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
         (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
         (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
         (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
-    ).toDF("id", "a", "b", "features", "label")
+      .toDF("id", "a", "b", "features", "label")
+      .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
     testRFormulaTransform[(Double, String, Int)](original, model, expected)
   }
 
@@ -344,15 +338,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
       .toDF("id", "a", "b")
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expected = Seq(
+      (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+      (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+      (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+      (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+      .toDF("id", "a", "b", "features", "label")
     val expectedAttrs = new AttributeGroup(
       "features",
       Array(
         new BinaryAttribute(Some("a_bar"), Some(1)),
         new BinaryAttribute(Some("a_foo"), Some(2)),
         new NumericAttribute(Some("b"), Some(3))))
-    assert(attrs === expectedAttrs)
+    testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs)
+
   }
 
   test("vector attribute generation") {
@@ -360,14 +359,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
       .toDF("id", "vec")
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val attrs = new AttributeGroup("vec", 2)
+    val expected = Seq(
+      (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0),
+      (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0))
+      .toDF("id", "vec", "features", "label")
+      .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label")
     val expectedAttrs = new AttributeGroup(
       "features",
       Array[Attribute](
         new NumericAttribute(Some("vec_0"), Some(1)),
         new NumericAttribute(Some("vec_1"), Some(2))))
-    assert(attrs === expectedAttrs)
+
+    testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs)
   }
 
   test("vector attribute generation with unnamed input attrs") {
@@ -381,31 +385,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         NumericAttribute.defaultAttr)).toMetadata()
     val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
     val model = formula.fit(original)
-    val result = model.transform(original)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expected = Seq(
+      (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0),
+      (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)
+    ).toDF("id", "vec2", "features", "label")
+      .select($"id", $"vec2".as("vec2", metadata), $"features", $"label")
     val expectedAttrs = new AttributeGroup(
       "features",
       Array[Attribute](
         new NumericAttribute(Some("vec2_0"), Some(1)),
         new NumericAttribute(Some("vec2_1"), Some(2))))
-    assert(attrs === expectedAttrs)
+    testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs)
   }
 
   test("numeric interaction") {
     val formula = new RFormula().setFormula("a ~ b:c:d")
     val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d")
     val model = formula.fit(original)
-    val result = model.transform(original)
     val expected = Seq(
       (1, 2, 4, 2, Vectors.dense(16.0), 1.0),
       (2, 3, 4, 1, Vectors.dense(12.0), 2.0)
     ).toDF("a", "b", "c", "d", "features", "label")
-    testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
     val expectedAttrs = new AttributeGroup(
       "features",
       Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
-    assert(attrs === expectedAttrs)
+    testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs)
   }
 
   test("factor numeric interaction") {
@@ -414,7 +418,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
       Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
         .toDF("id", "a", "b")
     val model = formula.fit(original)
-    val result = model.transform(original)
     val expected = Seq(
       (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
       (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
@@ -423,15 +426,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
       (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
       (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)
     ).toDF("id", "a", "b", "features", "label")
-    testRFormulaTransform[(Int, String, Int)](original, model, expected)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
     val expectedAttrs = new AttributeGroup(
       "features",
       Array[Attribute](
         new NumericAttribute(Some("a_baz:b"), Some(1)),
         new NumericAttribute(Some("a_bar:b"), Some(2)),
         new NumericAttribute(Some("a_foo:b"), Some(3))))
-    assert(attrs === expectedAttrs)
+    testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs)
   }
 
   test("factor factor interaction") {
@@ -439,14 +440,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     val original =
       Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
     val model = formula.fit(original)
-    val result = model.transform(original)
     val expected = Seq(
       (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
       (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
       (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)
     ).toDF("id", "a", "b", "features", "label")
     testRFormulaTransform[(Int, String, String)](original, model, expected)
-    val attrs = AttributeGroup.fromStructField(result.schema("features"))
     val expectedAttrs = new AttributeGroup(
       "features",
       Array[Attribute](
@@ -454,7 +453,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         new NumericAttribute(Some("a_bar:b_zz"), Some(2)),
         new NumericAttribute(Some("a_foo:b_zq"), Some(3)),
         new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
-    assert(attrs === expectedAttrs)
+    testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs)
   }
 
   test("read/write: RFormula") {
@@ -517,9 +516,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
 
     // Handle unseen features.
     val formula1 = new RFormula().setFormula("id ~ a + b")
-    intercept[SparkException] {
-      formula1.fit(df1).transform(df2).collect()
-    }
+    testTransformerByInterceptingException[(Int, String, String)](
+      df2,
+      formula1.fit(df1),
+      "Unseen label:",
+      "features")
     val model1 = formula1.setHandleInvalid("skip").fit(df1)
     val model2 = formula1.setHandleInvalid("keep").fit(df1)
 
@@ -538,21 +539,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
 
     // Handle unseen labels.
     val formula2 = new RFormula().setFormula("b ~ a + id")
-    intercept[SparkException] {
-      formula2.fit(df1).transform(df2).collect()
-    }
+    testTransformerByInterceptingException[(Int, String, String)](
+      df2,
+      formula2.fit(df1),
+      "Unseen label:",
+      "label")
+
     val model3 = formula2.setHandleInvalid("skip").fit(df1)
     val model4 = formula2.setHandleInvalid("keep").fit(df1)
 
+    val attr = NominalAttribute.defaultAttr
     val expected3 = Seq(
       (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
       (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
     ).toDF("id", "a", "b", "features", "label")
+      .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
+
     val expected4 = Seq(
       (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
       (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
       (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
     ).toDF("id", "a", "b", "features", "label")
+      .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
 
     testRFormulaTransform[(Int, String, String)](df2, model3, expected3)
     testRFormulaTransform[(Int, String, String)](df2, model4, expected4)

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 673a146..cf09418 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
-class SQLTransformerSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class SQLTransformerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -37,14 +34,22 @@ class SQLTransformerSuite
     val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
     val sqlTrans = new SQLTransformer().setStatement(
       "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
-    val result = sqlTrans.transform(original)
-    val resultSchema = sqlTrans.transformSchema(original.schema)
-    val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
+     val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
       .toDF("id", "v1", "v2", "v3", "v4")
-    assert(result.schema.toString == resultSchema.toString)
-    assert(resultSchema == expected.schema)
-    assert(result.collect().toSeq == expected.collect().toSeq)
-    assert(original.sparkSession.catalog.listTables().count() == 0)
+    val resultSchema = sqlTrans.transformSchema(original.schema)
+    testTransformerByGlobalCheckFunc[(Int, Double, Double)](
+      original,
+      sqlTrans,
+      "id",
+      "v1",
+      "v2",
+      "v3",
+      "v4") { rows =>
+      assert(rows.head.schema.toString == resultSchema.toString)
+      assert(resultSchema == expected.schema)
+      assert(rows == expected.collect().toSeq)
+      assert(original.sparkSession.catalog.listTables().count() == 0)
+    }
   }
 
   test("read/write") {
@@ -63,13 +68,13 @@ class SQLTransformerSuite
   }
 
   test("SPARK-22538: SQLTransformer should not unpersist given dataset") {
-    val df = spark.range(10)
+    val df = spark.range(10).toDF()
     df.cache()
     df.count()
     assert(df.storageLevel != StorageLevel.NONE)
-    new SQLTransformer()
+    val sqlTrans = new SQLTransformer()
       .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
-      .transform(df)
+    testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => }
     assert(df.storageLevel != StorageLevel.NONE)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 350ba44..c5c49d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
-class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class StandardScalerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     )
   }
 
-  def assertResult(df: DataFrame): Unit = {
-    df.select("standardized_features", "expected").collect().foreach {
-      case Row(vector1: Vector, vector2: Vector) =>
-        assert(vector1 ~== vector2 absTol 1E-5,
-          "The vector value is not correct after standardization.")
-    }
+  def assertResult: Row => Unit = {
+    case Row(vector1: Vector, vector2: Vector) =>
+      assert(vector1 ~== vector2 absTol 1E-5,
+        "The vector value is not correct after standardization.")
   }
 
   test("params") {
@@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     val standardScaler0 = standardScalerEst0.fit(df0)
     MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0)
 
-    assertResult(standardScaler0.transform(df0))
+    testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")(
+      assertResult)
   }
 
   test("Standardization with setter") {
@@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setWithStd(false)
       .fit(df3)
 
-    assertResult(standardScaler1.transform(df1))
-    assertResult(standardScaler2.transform(df2))
-    assertResult(standardScaler3.transform(df3))
+    testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")(
+      assertResult)
+    testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")(
+      assertResult)
+    testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")(
+      assertResult)
   }
 
   test("sparse data and withMean") {
@@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setWithMean(true)
       .setWithStd(false)
       .fit(df)
-    assertResult(standardScaler.transform(df))
+    testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")(
+      assertResult)
   }
 
   test("StandardScaler read/write") {
@@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(newInstance.std === instance.std)
     assert(newInstance.mean === instance.mean)
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/279b3db8/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 5262b14..21259a5 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -17,28 +17,20 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
-
-object StopWordsRemoverSuite extends SparkFunSuite {
-  def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
-    t.transform(dataset)
-      .select("filtered", "expected")
-      .collect()
-      .foreach { case Row(tokens, wantedTokens) =>
-        assert(tokens === wantedTokens)
-    }
-  }
-}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
 
-class StopWordsRemoverSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
 
-  import StopWordsRemoverSuite._
   import testImplicits._
 
+  def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = {
+    testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") {
+       case Row(tokens: Seq[_], wantedTokens: Seq[_]) =>
+         assert(tokens === wantedTokens)
+    }
+  }
+
   test("StopWordsRemover default") {
     val remover = new StopWordsRemover()
       .setInputCol("raw")
@@ -151,9 +143,10 @@ class StopWordsRemoverSuite
       .setOutputCol(outputCol)
     val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol)
 
-    val thrown = intercept[IllegalArgumentException] {
-      testStopWordsRemover(remover, dataSet)
-    }
-    assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+    testTransformerByInterceptingException[(Array[String], Array[String])](
+      dataSet,
+      remover,
+      s"requirement failed: Column $outputCol already exists.",
+      "expected")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org