Posted to commits@spark.apache.org by me...@apache.org on 2016/04/11 18:28:33 UTC

[2/2] spark git commit: [SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs

[SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs

## What changes were proposed in this pull request?

This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was previously the input type. It does not change any output types. In Java, `Dataset[_]` maps to `Dataset<?>`, which includes `Dataset<Row>`. Some implementations were changed so that they still return `DataFrame`. Tests and examples were updated accordingly. Note that this is a breaking change for subclasses of Transformer/Estimator.
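
For illustration, this is the shape of the change, distilled from the diffs below (the `LogisticRegression` usage is just an example):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Dataset

// Before: pipeline APIs were tied to DataFrame.
//   def fit(dataset: DataFrame): M
//   def transform(dataset: DataFrame): DataFrame
// After: any Dataset[_] is accepted; the output type is unchanged.
//   def fit(dataset: Dataset[_]): M
//   def transform(dataset: Dataset[_]): DataFrame

// Since DataFrame is Dataset[Row], existing DataFrame-based callers still
// compile, and typed datasets can now be passed directly:
def example(typed: Dataset[LabeledPoint]): LogisticRegressionModel =
  new LogisticRegression().fit(typed)
```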

Conveniently, we don't have to rename the input argument, which has been `dataset` since Spark 1.2.

TODOs:
- [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920)
- [x] Python
- [x] add a new test to accept Dataset[LabeledPoint] (see the sketch after this list)
- [x] remove unused imports of Dataset
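
A hypothetical sketch of the typed-input test (the sample data and `sqlContext` are illustrative, not the test as committed):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Dataset

import sqlContext.implicits._  // assumes a SQLContext named sqlContext in scope

val points: Dataset[LabeledPoint] = Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0))
).toDS()

// fit() accepts the typed Dataset directly; no .toDF() needed.
val model = new LogisticRegression().setMaxIter(5).fit(points)
```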

## How was this patch tested?

Existing unit tests with some modifications.

cc: rxin jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #12274 from mengxr/SPARK-14500.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c751fcf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c751fcf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c751fcf

Branch: refs/heads/master
Commit: 1c751fcf488189e5176546fe0d00f560ffcf1cec
Parents: e82d95b
Author: Xiangrui Meng <me...@databricks.com>
Authored: Mon Apr 11 09:28:28 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Apr 11 09:28:28 2016 -0700

----------------------------------------------------------------------
 .../examples/ml/JavaDeveloperApiExample.java    |  2 +-
 .../spark/examples/ml/DeveloperApiExample.scala |  4 ++--
 .../scala/org/apache/spark/ml/Estimator.scala   | 16 ++++++++-----
 .../scala/org/apache/spark/ml/Pipeline.scala    | 12 +++++-----
 .../scala/org/apache/spark/ml/Predictor.scala   | 14 ++++++------
 .../scala/org/apache/spark/ml/Transformer.scala | 15 +++++++-----
 .../spark/ml/classification/Classifier.scala    |  6 ++---
 .../classification/DecisionTreeClassifier.scala |  4 ++--
 .../spark/ml/classification/GBTClassifier.scala |  6 ++---
 .../ml/classification/LogisticRegression.scala  |  8 +++----
 .../MultilayerPerceptronClassifier.scala        |  4 ++--
 .../spark/ml/classification/NaiveBayes.scala    |  4 ++--
 .../spark/ml/classification/OneVsRest.scala     | 10 ++++----
 .../ProbabilisticClassifier.scala               |  6 ++---
 .../classification/RandomForestClassifier.scala |  6 ++---
 .../spark/ml/clustering/BisectingKMeans.scala   |  8 +++----
 .../spark/ml/clustering/GaussianMixture.scala   |  6 ++---
 .../org/apache/spark/ml/clustering/KMeans.scala | 14 ++++++------
 .../org/apache/spark/ml/clustering/LDA.scala    | 24 ++++++++++----------
 .../BinaryClassificationEvaluator.scala         |  6 ++---
 .../apache/spark/ml/evaluation/Evaluator.scala  | 10 ++++----
 .../MulticlassClassificationEvaluator.scala     |  6 ++---
 .../ml/evaluation/RegressionEvaluator.scala     |  6 ++---
 .../org/apache/spark/ml/feature/Binarizer.scala |  3 ++-
 .../apache/spark/ml/feature/Bucketizer.scala    |  3 ++-
 .../apache/spark/ml/feature/ChiSqSelector.scala |  6 +++--
 .../spark/ml/feature/CountVectorizer.scala      |  8 ++++---
 .../org/apache/spark/ml/feature/HashingTF.scala |  5 ++--
 .../scala/org/apache/spark/ml/feature/IDF.scala |  6 +++--
 .../apache/spark/ml/feature/Interaction.scala   |  6 ++---
 .../apache/spark/ml/feature/MaxAbsScaler.scala  |  6 +++--
 .../apache/spark/ml/feature/MinMaxScaler.scala  |  6 +++--
 .../apache/spark/ml/feature/OneHotEncoder.scala |  5 ++--
 .../scala/org/apache/spark/ml/feature/PCA.scala |  6 +++--
 .../spark/ml/feature/QuantileDiscretizer.scala  | 13 +++++++----
 .../org/apache/spark/ml/feature/RFormula.scala  | 18 ++++++++-------
 .../spark/ml/feature/SQLTransformer.scala       |  6 ++---
 .../spark/ml/feature/StandardScaler.scala       |  6 +++--
 .../spark/ml/feature/StopWordsRemover.scala     |  5 ++--
 .../apache/spark/ml/feature/StringIndexer.scala | 13 +++++++----
 .../spark/ml/feature/VectorAssembler.scala      |  7 +++---
 .../apache/spark/ml/feature/VectorIndexer.scala |  8 ++++---
 .../apache/spark/ml/feature/VectorSlicer.scala  |  5 ++--
 .../org/apache/spark/ml/feature/Word2Vec.scala  |  8 ++++---
 .../ml/r/AFTSurvivalRegressionWrapper.scala     |  4 ++--
 .../org/apache/spark/ml/r/KMeansWrapper.scala   |  4 ++--
 .../apache/spark/ml/r/NaiveBayesWrapper.scala   |  4 ++--
 .../apache/spark/ml/recommendation/ALS.scala    | 10 ++++----
 .../ml/regression/AFTSurvivalRegression.scala   | 12 +++++-----
 .../ml/regression/DecisionTreeRegressor.scala   | 11 +++++----
 .../spark/ml/regression/GBTRegressor.scala      |  6 ++---
 .../GeneralizedLinearRegression.scala           |  4 ++--
 .../ml/regression/IsotonicRegression.scala      | 12 +++++-----
 .../spark/ml/regression/LinearRegression.scala  |  6 ++---
 .../ml/regression/RandomForestRegressor.scala   |  6 ++---
 .../apache/spark/ml/tuning/CrossValidator.scala | 12 +++++-----
 .../spark/ml/tuning/TrainValidationSplit.scala  | 11 +++++----
 .../org/apache/spark/mllib/linalg/Vectors.scala |  2 +-
 .../org/apache/spark/ml/PipelineSuite.scala     | 12 +++++++---
 .../LogisticRegressionSuite.scala               |  4 ++--
 .../MultilayerPerceptronClassifierSuite.scala   |  4 ++--
 .../ml/classification/NaiveBayesSuite.scala     |  4 ++--
 .../ml/classification/OneVsRestSuite.scala      |  6 ++---
 .../ml/clustering/BisectingKMeansSuite.scala    |  4 ++--
 .../ml/clustering/GaussianMixtureSuite.scala    |  4 ++--
 .../spark/ml/clustering/KMeansSuite.scala       |  4 ++--
 .../apache/spark/ml/clustering/LDASuite.scala   |  4 ++--
 .../apache/spark/ml/feature/NGramSuite.scala    |  4 ++--
 .../ml/feature/StopWordsRemoverSuite.scala      |  4 ++--
 .../spark/ml/feature/StringIndexerSuite.scala   |  2 +-
 .../spark/ml/feature/TokenizerSuite.scala       |  4 ++--
 .../GeneralizedLinearRegressionSuite.scala      |  8 +++++++
 .../spark/ml/tuning/CrossValidatorSuite.scala   |  8 +++----
 .../ml/tuning/TrainValidationSplitSuite.scala   |  6 ++---
 .../spark/ml/util/DefaultReadWriteTest.scala    |  4 ++--
 75 files changed, 296 insertions(+), 240 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index fbd8817..0ba9478 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -146,7 +146,7 @@ class MyJavaLogisticRegression
 
   // This method is used by fit().
   // In Java, we have to make it public since Java does not understand Scala's protected modifier.
-  public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) {
+  public MyJavaLogisticRegressionModel train(Dataset<?> dataset) {
     // Extract columns from data using helper method.
     JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index c1f63c6..8d127f9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 
 /**
  * A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String)
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   // This method is used by fit()
-  override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
+  override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
     // Extract columns from data using helper method.
     val oldDataset = extractLabeledPoints(dataset)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 57e4165..1247882 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml
 
 import scala.annotation.varargs
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.param.{ParamMap, ParamPair}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
 
 /**
  * :: DeveloperApi ::
@@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    *                        Estimator's embedded ParamMap.
    * @return fitted model
    */
+  @Since("2.0.0")
   @varargs
-  def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+  def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
     val map = new ParamMap()
       .put(firstParamPair)
       .put(otherParamPairs: _*)
@@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    *                 These values override any specified in this Estimator's embedded ParamMap.
    * @return fitted model
    */
-  def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+  @Since("2.0.0")
+  def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
     copy(paramMap).fit(dataset)
   }
 
   /**
    * Fits a model to the input data.
    */
-  def fit(dataset: DataFrame): M
+  @Since("2.0.0")
+  def fit(dataset: Dataset[_]): M
 
   /**
    * Fits multiple models to the input data with multiple sets of parameters.
@@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    *                  These values override any specified in this Estimator's embedded ParamMap.
    * @return fitted models, matching the input parameter maps
    */
-  def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
+  @Since("2.0.0")
+  def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
     paramMaps.map(fit(dataset, _))
   }
 

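A usage sketch for the three `fit` overloads above (`demo` and its dataset are hypothetical; the dataset is assumed to have "label" and "features" columns):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset

def demo(ds: Dataset[_]): Unit = {
  val lr = new LogisticRegression()
  val m1 = lr.fit(ds)                                         // embedded params only
  val m2 = lr.fit(ds, lr.maxIter -> 10, lr.regParam -> 0.01)  // ParamPair varargs
  val grid = Array(ParamMap(lr.regParam -> 0.1), ParamMap(lr.regParam -> 1.0))
  val models = lr.fit(ds, grid)                               // one model per ParamMap
}
```
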
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index afefaaa..8206672 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") (
    * @param dataset input dataset
    * @return fitted pipeline
    */
-  @Since("1.2.0")
-  override def fit(dataset: DataFrame): PipelineModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): PipelineModel = {
     transformSchema(dataset.schema, logging = true)
     val theStages = $(stages)
     // Search for the last estimator.
@@ -291,10 +291,10 @@ class PipelineModel private[ml] (
     this(uid, stages.asScala.toArray)
   }
 
-  @Since("1.2.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
+    stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
   }
 
   @Since("1.2.0")

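Why `.toDF` appears in `PipelineModel.transform` above: each stage's `transform` returns a `DataFrame`, so the `foldLeft` accumulator must start as one. Since `DataFrame` is just `Dataset[Row]`, `dataset.toDF` is a logical conversion, not a data copy. A sketch of the typing:

```scala
// Without .toDF the fold would not typecheck: `dataset` is Dataset[_], while
// transformer.transform(...) returns DataFrame.
val out: DataFrame = stages.foldLeft(dataset.toDF)((cur, t) => t.transform(cur))
```
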
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index d23ae6f..81140d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -83,7 +83,7 @@ abstract class Predictor[
   /** @group setParam */
   def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
 
-  override def fit(dataset: DataFrame): M = {
+  override def fit(dataset: Dataset[_]): M = {
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
@@ -100,7 +100,7 @@ abstract class Predictor[
    * @param dataset  Training dataset
    * @return  Fitted model
    */
-  protected def train(dataset: DataFrame): M
+  protected def train(dataset: Dataset[_]): M
 
   /**
    * Returns the SQL DataType corresponding to the FeaturesType type parameter.
@@ -120,7 +120,7 @@ abstract class Predictor[
    * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
    * and put it in an RDD with strong types.
    */
-  protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
+  protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
     dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
@@ -171,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
    * @param dataset input dataset
    * @return transformed dataset with [[predictionCol]] of type [[Double]]
    */
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if ($(predictionCol).nonEmpty) {
       transformImpl(dataset)
     } else {
       this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
         " since no output columns were set.")
-      dataset
+      dataset.toDF
     }
   }
 
-  protected def transformImpl(dataset: DataFrame): DataFrame = {
+  protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf { (features: Any) =>
       predict(features.asInstanceOf[FeaturesType])
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 2538c0f..a3a2b55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml
 
 import scala.annotation.varargs
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage {
    * @param otherParamPairs other param pairs, overwrite embedded params
    * @return transformed dataset
    */
+  @Since("2.0.0")
   @varargs
   def transform(
-      dataset: DataFrame,
+      dataset: Dataset[_],
       firstParamPair: ParamPair[_],
       otherParamPairs: ParamPair[_]*): DataFrame = {
     val map = new ParamMap()
@@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage {
    * @param paramMap additional parameters, overwrite embedded params
    * @return transformed dataset
    */
-  def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+  @Since("2.0.0")
+  def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
     this.copy(paramMap).transform(dataset)
   }
 
   /**
    * Transforms the input dataset.
    */
-  def transform(dataset: DataFrame): DataFrame
+  @Since("2.0.0")
+  def transform(dataset: Dataset[_]): DataFrame
 
   override def copy(extra: ParamMap): Transformer
 }
@@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
     StructType(outputFields)
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))

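Subclasses of `UnaryTransformer` implement `createTransformFunc` rather than `transform`, so they inherit the new `Dataset[_]`-accepting `transform` for free. A hypothetical minimal subclass:

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical transformer: upper-cases a string column.
class UpperCaser(override val uid: String)
    extends UnaryTransformer[String, String, UpperCaser] {

  def this() = this(Identifiable.randomUID("upperCaser"))

  override protected def createTransformFunc: String => String = _.toUpperCase

  override protected def outputDataType: DataType = StringType

  override def copy(extra: ParamMap): UpperCaser = defaultCopy(extra)
}

// Usage: new UpperCaser().setInputCol("text").setOutputCol("textUpper")
//          .transform(df)   // df may be any Dataset[_]
```
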
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 8186afc..473e801 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @param dataset input dataset
    * @return transformed dataset
    */
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
 
     // Output selected columns only.
@@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
       logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
         " since no output columns were set.")
     }
-    outputData
+    outputData.toDF
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 4525bf7..300ae43 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 
 /**
@@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
-  override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
+  override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index a2150fb..46e8b89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 /**
@@ -149,7 +149,7 @@ final class GBTClassifier @Since("1.4.0") (
     }
   }
 
-  override protected def train(dataset: DataFrame): GBTClassificationModel = {
+  override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -220,7 +220,7 @@ final class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
-  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 268c3e3..4a3fe5c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -36,7 +36,7 @@ import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
@@ -257,12 +257,12 @@ class LogisticRegression @Since("1.2.0") (
     this
   }
 
-  override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     train(dataset, handlePersistence)
   }
 
-  protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
+  protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean):
       LogisticRegressionModel = {
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
@@ -544,7 +544,7 @@ class LogisticRegressionModel private[spark] (
    * @param dataset Test dataset to evaluate model on.
    */
   @Since("2.0.0")
-  def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+  def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
     // Handle possible missing or invalid prediction columns
     val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
     new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),

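A usage sketch for the `evaluate` change above (`lrModel` and `testData` are hypothetical):

```scala
// evaluate() now accepts any Dataset[_] as the test set.
val summary = lrModel.evaluate(testData)
// For binary problems the summary can be narrowed to get ROC metrics:
val auc = summary.asInstanceOf[BinaryLogisticRegressionSummary].areaUnderROC
```
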
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 79bb2a8..9ff5252 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTo
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 /** Params for Multilayer Perceptron. */
 private[ml] trait MultilayerPerceptronParams extends PredictorParams
@@ -199,7 +199,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
    * @param dataset Training dataset
    * @return Fitted model
    */
-  override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
+  override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = {
     val myLayers = $(layers)
     val labels = myLayers.last
     val lpData = extractLabeledPoints(dataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 483ef0d..267d63b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 /**
  * Params for Naive Bayes Classifiers.
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") (
   def setModelType(value: String): this.type = set(modelType, value)
   setDefault(modelType -> OldNaiveBayes.Multinomial)
 
-  override protected def train(dataset: DataFrame): NaiveBayesModel = {
+  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
     NaiveBayesModel.fromOld(oldModel, this)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 263d54c..4de1b87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -33,7 +33,7 @@ import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -140,8 +140,8 @@ final class OneVsRestModel private[ml] (
     validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
   }
 
-  @Since("1.4.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     // Check schema
     transformSchema(dataset.schema, logging = true)
 
@@ -293,8 +293,8 @@ final class OneVsRest @Since("1.4.0") (
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
   }
 
-  @Since("1.4.0")
-  override def fit(dataset: DataFrame): OneVsRestModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): OneVsRestModel = {
     transformSchema(dataset.schema)
 
     // determine number of classes either from metadata if provided, or via computation.

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 865614a..d00fee1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[
    * @param dataset input dataset
    * @return transformed dataset
    */
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[
       this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
         " since no output columns were set.")
     }
-    outputData
+    outputData.toDF
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index cb42532..9d80b8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 
@@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") (
   override def setFeatureSubsetStrategy(value: String): this.type =
     super.setFeatureSubsetStrategy(value)
 
-  override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
+  override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
-  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 55f751c..6cc9117 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.
   {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
@@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] (
    * centers.
    */
   @Since("2.0.0")
-  def computeCost(dataset: DataFrame): Double = {
+  def computeCost(dataset: Dataset[_]): Double = {
     SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
     val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
@@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
   def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
 
   @Since("2.0.0")
-  override def fit(dataset: DataFrame): BisectingKMeansModel = {
+  override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
     val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
 
     val bkm = new MLlibBisectingKMeans()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 120bf3c..ead8ad7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
@@ -80,7 +80,7 @@ class GaussianMixtureModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val predUDF = udf((vector: Vector) => predict(vector))
     val probUDF = udf((vector: Vector) => predictProbability(vector))
     dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
@@ -238,7 +238,7 @@ class GaussianMixture @Since("2.0.0") (
   def setSeed(value: Long): this.type = set(seed, value)
 
   @Since("2.0.0")
-  override def fit(dataset: DataFrame): GaussianMixtureModel = {
+  override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
     val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
 
     val algo = new MLlibGM()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a8beef8..d716bc6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
@@ -105,8 +105,8 @@ class KMeansModel private[ml] (
     copyValues(copied, extra)
   }
 
-  @Since("1.5.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -126,8 +126,8 @@ class KMeansModel private[ml] (
    * model on the given data.
    */
   // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
-  @Since("1.6.0")
-  def computeCost(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  def computeCost(dataset: Dataset[_]): Double = {
     SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
     val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
@@ -254,8 +254,8 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.5.0")
-  override def fit(dataset: DataFrame): KMeansModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): KMeansModel = {
     val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
 
     val algo = new MLlibKMeans()

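A usage sketch for the updated `fit`/`computeCost` signatures above (`points` is a hypothetical dataset with a "features" column):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.Dataset

def cluster(points: Dataset[_]): Double = {
  val model = new KMeans().setK(3).fit(points)
  // computeCost() also accepts any Dataset[_]: it returns the sum of
  // squared distances of points to their nearest center.
  model.computeCost(points)
}
```
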
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 89a7a4c..c57ceba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
 import org.apache.spark.mllib.impl.PeriodicCheckpointer
 import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
 import org.apache.spark.sql.types.StructType
 
@@ -402,15 +402,15 @@ sealed abstract class LDAModel private[ml] (
    *          is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
    *          This implementation may be changed in the future.
    */
-  @Since("1.6.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
       val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
-      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
+      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
         " such as topicDistributionCol to produce results.")
-      dataset
+      dataset.toDF
     }
   }
 
@@ -455,8 +455,8 @@ sealed abstract class LDAModel private[ml] (
    * @param dataset  test corpus to use for calculating log likelihood
    * @return variational lower bound on the log likelihood of the entire corpus
    */
-  @Since("1.6.0")
-  def logLikelihood(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  def logLikelihood(dataset: Dataset[_]): Double = {
     val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
     oldLocalModel.logLikelihood(oldDataset)
   }
@@ -472,8 +472,8 @@ sealed abstract class LDAModel private[ml] (
    * @param dataset test corpus to use for calculating perplexity
    * @return Variational upper bound on log perplexity per token.
    */
-  @Since("1.6.0")
-  def logPerplexity(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  def logPerplexity(dataset: Dataset[_]): Double = {
     val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
     oldLocalModel.logPerplexity(oldDataset)
   }
@@ -840,8 +840,8 @@ class LDA @Since("1.6.0") (
   @Since("1.6.0")
   override def copy(extra: ParamMap): LDA = defaultCopy(extra)
 
-  @Since("1.6.0")
-  override def fit(dataset: DataFrame): LDAModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): LDAModel = {
     transformSchema(dataset.schema, logging = true)
     val oldLDA = new OldLDA()
       .setK($(k))
@@ -873,7 +873,7 @@ class LDA @Since("1.6.0") (
 private[clustering] object LDA extends DefaultParamsReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
-  def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
+  def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = {
     dataset
       .withColumn("docId", monotonicallyIncreasingId())
       .select("docId", featuresCol)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 337ffbe..bde8c27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.types.DoubleType
 
 /**
@@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
 
   setDefault(metricName -> "areaUnderROC")
 
-  @Since("1.2.0")
-  override def evaluate(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index 0f22cca..5f765c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.param.{ParamMap, Params}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
 
 /**
  * :: DeveloperApi ::
@@ -36,8 +36,8 @@ abstract class Evaluator extends Params {
    * @param paramMap parameter map that specifies the input columns and output metrics
    * @return metric
    */
-  @Since("1.5.0")
-  def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
+  @Since("2.0.0")
+  def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = {
     this.copy(paramMap).evaluate(dataset)
   }
 
@@ -46,8 +46,8 @@ abstract class Evaluator extends Params {
    * @param dataset a dataset that contains labels/observations and predictions.
    * @return metric
    */
-  @Since("1.5.0")
-  def evaluate(dataset: DataFrame): Double
+  @Since("2.0.0")
+  def evaluate(dataset: Dataset[_]): Double
 
   /**
    * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)

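Third-party `Evaluator` subclasses now override the `Dataset[_]` signature. A hypothetical minimal implementation:

```scala
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{abs, avg, col}

// Hypothetical metric: mean absolute error over "prediction"/"label" columns.
class MeanAbsErrorEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("maeEval"))

  override def evaluate(dataset: Dataset[_]): Double =
    dataset.select(avg(abs(col("prediction") - col("label")))).head().getDouble(0)

  // Smaller is better for an error metric.
  override def isLargerBetter: Boolean = false

  override def copy(extra: ParamMap): MeanAbsErrorEvaluator = defaultCopy(extra)
}
```
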
http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 55ff443..3acfc22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.types.DoubleType
 
 /**
@@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
 
   setDefault(metricName -> "f1")
 
-  @Since("1.5.0")
-  override def evaluate(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 9976d7e..4134e2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, FloatType}
 
@@ -70,8 +70,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
 
   setDefault(metricName -> "rmse")
 
-  @Since("1.4.0")
-  override def evaluate(dataset: DataFrame): Double = {
+  @Since("2.0.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     val predictionColName = $(predictionCol)
     val predictionType = schema($(predictionCol)).dataType

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2f8e3a0..898ac2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -64,7 +64,8 @@ final class Binarizer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
     val schema = dataset.schema
     val inputType = schema($(inputCol)).dataType

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c..10e622a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
     val bucketizer = udf { feature: Double =>
       Bucketizer.binarySearchForBuckets($(splits), feature)

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index b9e9d56..cfecae7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String)
   /** @group setParam */
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
-  override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
       case Row(label: Double, features: Vector) =>
@@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] (
   /** @group setParam */
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val newField = transformedSchema.last
     val selector = udf { chiSqSelector.transform _ }

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 00abbbe..922670a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
@@ -147,7 +147,8 @@ class CountVectorizer(override val uid: String)
 
   setDefault(vocabSize -> (1 << 18), minDF -> 1)
 
-  override def fit(dataset: DataFrame): CountVectorizerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): CountVectorizerModel = {
     transformSchema(dataset.schema, logging = true)
     val vocSize = $(vocabSize)
     val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
@@ -224,7 +225,8 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
   /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
   private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if (broadcastDict.isEmpty) {
       val dict = vocabulary.zipWithIndex.toMap

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 0f7ae5a..467ad73 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidat
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{ArrayType, StructType}
 
@@ -77,7 +77,8 @@ class HashingTF(override val uid: String)
   /** @group setParam */
   def setBinary(value: Boolean): this.type = set(binary, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
     val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
     val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
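
HashingTF is a pure Transformer, so only transform changes here. A sketch with a typed input (hypothetical data; tuple columns are named "_1" and "_2"):

    import org.apache.spark.ml.feature.HashingTF
    import sqlContext.implicits._

    val ds = Seq((0, Seq("spark", "ml", "spark")), (1, Seq("dataset", "api"))).toDS()
    val tf = new HashingTF()
      .setInputCol("_2")
      .setOutputCol("tf")
      .setNumFeatures(1 << 10)
    tf.transform(ds)           // Dataset[_] in, DataFrame out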

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f36cf50..5075b78 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
   /** @group setParam */
   def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
 
-  override def fit(dataset: DataFrame): IDFModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): IDFModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val idf = new feature.IDF($(minDocFreq)).fit(input)
@@ -115,7 +116,8 @@ class IDFModel private[ml] (
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val idf = udf { vec: Vector => idfModel.transform(vec) }
     dataset.withColumn($(outputCol), idf(col($(inputCol))))
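
The IDF estimator and its model follow the same pattern. A minimal sketch (hypothetical vectors, assuming sqlContext.implicits._):

    import org.apache.spark.ml.feature.IDF
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(1.0, 0.0, 2.0)),
      Tuple1(Vectors.dense(0.0, 1.0, 0.0))).toDF("features")
    val idfModel = new IDF()
      .setInputCol("features")
      .setOutputCol("idf")
      .fit(df)
    idfModel.transform(df)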

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index d3fe6e5..9ca34e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.Transformer
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
     StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
   }
 
-  @Since("1.6.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)
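
Note that Interaction's @Since tag moves from 1.6.0 to 2.0.0 because the Dataset[_] overload is a new signature. A sketch with a typed pair Dataset (hypothetical data):

    import org.apache.spark.ml.feature.Interaction
    import sqlContext.implicits._

    val ds = Seq((1.0, 2.0), (3.0, 4.0)).toDS()   // columns "_1" and "_2"
    new Interaction()
      .setInputCols(Array("_1", "_2"))
      .setOutputCol("interaction")
      .transform(ds)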

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 7de5a4d..e9df600 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def fit(dataset: DataFrame): MaxAbsScalerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val summary = Statistics.colStats(input)
@@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] (
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // TODO: this looks hack, we may have to handle sparse and dense vectors separately.
     val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
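
MaxAbsScaler usage under the new signatures, as a minimal sketch (hypothetical vectors):

    import org.apache.spark.ml.feature.MaxAbsScaler
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(1.0, -8.0)),
      Tuple1(Vectors.dense(2.0, 4.0))).toDF("features")
    val model = new MaxAbsScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .fit(df)
    model.transform(df)        // divides each dimension by its max absolute value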

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index b13684a..125becb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String)
   /** @group setParam */
   def setMax(value: Double): this.type = set(max, value)
 
-  override def fit(dataset: DataFrame): MinMaxScalerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val summary = Statistics.colStats(input)
@@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] (
   /** @group setParam */
   def setMax(value: Double): this.type = set(max, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
     val minArray = originalMin.toArray
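
The same shape applies to MinMaxScaler; a minimal sketch (hypothetical vectors):

    import org.apache.spark.ml.feature.MinMaxScaler
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(0.0, 10.0)),
      Tuple1(Vectors.dense(5.0, 20.0))).toDF("features")
    val model = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setMin(0.0)
      .setMax(1.0)
      .fit(df)
    model.transform(df)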
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 4f67042..9935779 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
 
@@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
     StructType(outputFields)
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     // schema transformation
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
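
A typed Dataset of category indices works directly; single-value Datasets get the default column name "value". Sketch (hypothetical data):

    import org.apache.spark.ml.feature.OneHotEncoder
    import sqlContext.implicits._

    val ds = Seq(0.0, 1.0, 2.0, 1.0).toDS()   // Dataset[Double], column "value"
    new OneHotEncoder()
      .setInputCol("value")
      .setOutputCol("onehot")
      .transform(ds)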

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 305c3d1..9cf722e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   /**
    * Computes a [[PCAModel]] that contains the principal components of the input vectors.
    */
-  override def fit(dataset: DataFrame): PCAModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): PCAModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
     val pca = new feature.PCA(k = $(k))
@@ -124,7 +125,8 @@ class PCAModel private[ml] (
    * NOTE: Vectors to be transformed must be the same length
    * as the source vectors given to [[PCA.fit()]].
    */
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
     val pcaOp = udf { pcaModel.transform _ }
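
PCA under the new signatures, as a minimal sketch (hypothetical 3-dimensional vectors, k = 2):

    import org.apache.spark.ml.feature.PCA
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(1.0, 0.0, 7.0)),
      Tuple1(Vectors.dense(2.0, 3.0, 4.0)),
      Tuple1(Vectors.dense(4.0, 1.0, 2.0))).toDF("features")
    val model = new PCA()
      .setInputCol("features")
      .setOutputCol("pca")
      .setK(2)
      .fit(df)
    model.transform(df)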

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e486e92..efe8b93 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.param.{IntParam, _}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -87,7 +87,8 @@ final class QuantileDiscretizer(override val uid: String)
     StructType(outputFields)
   }
 
-  override def fit(dataset: DataFrame): Bucketizer = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): Bucketizer = {
     val samples = QuantileDiscretizer
       .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
       .map { case Row(feature: Double) => feature }
@@ -112,13 +113,15 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
   /**
    * Sampling from the given dataset to collect quantile statistics.
    */
-  private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
+  private[feature]
+  def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
     val totalSamples = dataset.count()
     require(totalSamples > 0,
       "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
     val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
     val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
-    dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+    dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
+      .collect()
   }
 
   /**
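
The toDF call inside getSampledInput is needed because sample on a Dataset[_] would keep the statically unknown element type, while the helper must return Array[Row]; converting to DataFrame first keeps the collect() result as Rows. End-user code is unaffected; a minimal sketch (hypothetical data):

    import org.apache.spark.ml.feature.QuantileDiscretizer
    import sqlContext.implicits._

    val ds = Seq(1.0, 2.0, 3.0, 10.0, 100.0).toDS()   // column "value"
    val bucketizer = new QuantileDiscretizer()
      .setInputCol("value")
      .setOutputCol("bucket")
      .setNumBuckets(3)
      .fit(ds)                 // fit returns a Bucketizer
    bucketizer.transform(ds)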

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 12a76db..3ac6c77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types._
 
 /**
@@ -103,7 +103,8 @@ class RFormula(override val uid: String)
     RFormulaParser.parse($(formula)).hasIntercept
   }
 
-  override def fit(dataset: DataFrame): RFormulaModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): RFormulaModel = {
     require(isDefined(formula), "Formula must be defined first.")
     val parsedFormula = RFormulaParser.parse($(formula))
     val resolvedFormula = parsedFormula.resolve(dataset.schema)
@@ -204,7 +205,8 @@ class RFormulaModel private[feature](
     private[ml] val pipelineModel: PipelineModel)
   extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     checkCanTransform(dataset.schema)
     transformLabel(pipelineModel.transform(dataset))
   }
@@ -232,10 +234,10 @@ class RFormulaModel private[feature](
 
   override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
 
-  private def transformLabel(dataset: DataFrame): DataFrame = {
+  private def transformLabel(dataset: Dataset[_]): DataFrame = {
     val labelName = resolvedFormula.label
     if (hasLabelCol(dataset.schema)) {
-      dataset
+      dataset.toDF
     } else if (dataset.schema.exists(_.name == labelName)) {
       dataset.schema(labelName).dataType match {
         case _: NumericType | BooleanType =>
@@ -246,7 +248,7 @@ class RFormulaModel private[feature](
     } else {
       // Ignore the label field. This is a hack so that this transformer can also work on test
       // datasets in a Pipeline.
-      dataset
+      dataset.toDF
     }
   }
 
@@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
   def this(columnsToPrune: Set[String]) =
     this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
     dataset.select(columnsToKeep.map(dataset.col): _*)
   }
@@ -396,7 +398,7 @@ private class VectorAttributeRewriter(
   def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
     this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val metadata = {
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
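
transformLabel now returns dataset.toDF in its pass-through branches because the method's declared return type stays DataFrame while the input has widened to Dataset[_]. Public usage is unchanged; a minimal sketch (hypothetical columns):

    import org.apache.spark.ml.feature.RFormula
    import sqlContext.implicits._

    val df = Seq((0.0, 1.0, 2.0), (1.0, 3.0, 4.0)).toDF("y", "x1", "x2")
    val model = new RFormula()
      .setFormula("y ~ x1 + x2")
      .fit(df)
    model.transform(df)        // appends "features" and "label" columns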

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b..95fe942 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,8 +63,8 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   private val tableIdentifier: String = "__THIS__"
 
-  @Since("1.6.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val tableName = Identifiable.randomUID(uid)
     dataset.registerTempTable(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
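
SQLTransformer under the new signature, as a sketch (hypothetical statement and data):

    import org.apache.spark.ml.feature.SQLTransformer
    import sqlContext.implicits._

    val df = Seq((0, 1.0, 3.0), (1, 2.0, 5.0)).toDF("id", "v1", "v2")
    new SQLTransformer()
      .setStatement("SELECT *, v1 + v2 AS v3 FROM __THIS__")
      .transform(df)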

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 26ee8e1..118a6e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   /** @group setParam */
   def setWithStd(value: Boolean): this.type = set(withStd, value)
 
-  override def fit(dataset: DataFrame): StandardScalerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): StandardScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
@@ -135,7 +136,8 @@ class StandardScalerModel private[ml] (
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
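
StandardScaler sketch (hypothetical dense vectors; withMean requires dense input):

    import org.apache.spark.ml.feature.StandardScaler
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(1.0, 10.0)),
      Tuple1(Vectors.dense(3.0, 20.0))).toDF("features")
    val model = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setWithMean(true)
      .setWithStd(true)
      .fit(df)
    model.transform(df)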

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0a0e0b0..b96bc48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
 
@@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String)
 
   setDefault(stopWords -> StopWords.English, caseSensitive -> false)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
     val t = if ($(caseSensitive)) {
         val stopWordsSet = $(stopWords).toSet
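
StopWordsRemover sketch (hypothetical tokens):

    import org.apache.spark.ml.feature.StopWordsRemover
    import sqlContext.implicits._

    val df = Seq(Tuple1(Seq("saw", "the", "red", "balloon"))).toDF("raw")
    new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .transform(df)           // drops "the" with the default English stop words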

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index faa0f6f..7e0d374 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
@@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
 
-  override def fit(dataset: DataFrame): StringIndexerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): StringIndexerModel = {
     val counts = dataset.select(col($(inputCol)).cast(StringType))
       .rdd
       .map(_.getString(0))
@@ -144,11 +145,12 @@ class StringIndexerModel (
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     if (!dataset.schema.fieldNames.contains($(inputCol))) {
       logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
         "Skip StringIndexerModel.")
-      return dataset
+      return dataset.toDF
     }
     validateAndTransformSchema(dataset.schema)
 
@@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String)
     StructType(outputFields)
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val inputColSchema = dataset.schema($(inputCol))
     // If the labels array is empty use column metadata
     val values = if ($(labels).isEmpty) {
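
The return dataset.toDF in the missing-column branch above is required because transform's declared return type remains DataFrame. A typed sketch (hypothetical data):

    import org.apache.spark.ml.feature.StringIndexer
    import sqlContext.implicits._

    val ds = Seq("a", "b", "a", "c").toDS()   // Dataset[String], column "value"
    val model = new StringIndexer()
      .setInputCol("value")
      .setOutputCol("index")
      .fit(ds)
    model.transform(ds)        // "a" -> 0.0 (most frequent label)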

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 957e8e7..4d3e46e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     // Schema transformation.
     val schema = dataset.schema
-    lazy val first = dataset.first()
+    lazy val first = dataset.toDF.first()
     val attrs = $(inputCols).flatMap { c =>
       val field = schema(c)
       val index = schema.fieldIndex(c)
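
Here dataset.toDF.first() replaces dataset.first() because first() on a Dataset[_] returns the statically unknown element type, while the surrounding code needs a Row. Usage sketch (hypothetical columns):

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq((1.0, 0.5, Vectors.dense(2.0, 3.0))).toDF("x", "y", "v")
    new VectorAssembler()
      .setInputCols(Array("x", "y", "v"))
      .setOutputCol("features")
      .transform(df)           // features = [1.0, 0.5, 2.0, 3.0]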

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index bf4aef2..68b699d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.util.collection.OpenHashSet
@@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def fit(dataset: DataFrame): VectorIndexerModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): VectorIndexerModel = {
     transformSchema(dataset.schema, logging = true)
     val firstRow = dataset.select($(inputCol)).take(1)
     require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
@@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] (
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
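
VectorIndexer sketch (hypothetical vectors; maxCategories controls which dimensions are treated as categorical):

    import org.apache.spark.ml.feature.VectorIndexer
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(
      Tuple1(Vectors.dense(0.0, 1.0)),
      Tuple1(Vectors.dense(1.0, 2.5))).toDF("features")
    val model = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .setMaxCategories(2)
      .fit(df)
    model.transform(df)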

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index b60e82d..7a9468b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StructType
 
@@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
     val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
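
VectorSlicer sketch (hypothetical vector, selecting dimensions 0 and 2):

    import org.apache.spark.ml.feature.VectorSlicer
    import org.apache.spark.mllib.linalg.Vectors
    import sqlContext.implicits._

    val df = Seq(Tuple1(Vectors.dense(1.0, 2.0, 3.0))).toDF("features")
    new VectorSlicer()
      .setInputCol("features")
      .setOutputCol("sliced")
      .setIndices(Array(0, 2))
      .transform(df)           // sliced = [1.0, 3.0]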

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 95bae1c..a726929 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
   /** @group setParam */
   def setMinCount(value: Int): this.type = set(minCount, value)
 
-  override def fit(dataset: DataFrame): Word2VecModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): Word2VecModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
     val wordVectors = new feature.Word2Vec()
@@ -219,7 +220,8 @@ class Word2VecModel private[ml] (
    * Transform a sentence column to a vector column to represent the whole sentence. The transform
    * is performed by averaging all word vectors it contains.
    */
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val vectors = wordVectors.getVectors
       .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
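
Word2Vec sketch (hypothetical sentences; minCount lowered so the toy vocabulary survives):

    import org.apache.spark.ml.feature.Word2Vec
    import sqlContext.implicits._

    val df = Seq(
      Tuple1("spark ml now accepts datasets".split(" ").toSeq),
      Tuple1("typed datasets are convenient".split(" ").toSeq)).toDF("text")
    val model = new Word2Vec()
      .setInputCol("text")
      .setOutputCol("vec")
      .setVectorSize(3)
      .setMinCount(0)
      .fit(df)
    model.transform(df)        // averages word vectors per sentence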

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 40590e7..2ae4115 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class AFTSurvivalRegressionWrapper private (
     pipeline: PipelineModel,
@@ -43,7 +43,7 @@ private[r] class AFTSurvivalRegressionWrapper private (
     features ++ Array("Log(scale)")
   }
 
-  def transform(dataset: DataFrame): DataFrame = {
+  def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ed735a4..ee51357 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class KMeansWrapper private (
     pipeline: PipelineModel) {
@@ -52,7 +52,7 @@ private[r] class KMeansWrapper private (
     }
   }
 
-  def transform(dataset: DataFrame): DataFrame = {
+  def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 07383d3..2cd709d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class NaiveBayesWrapper private (
     pipeline: PipelineModel,
@@ -36,7 +36,7 @@ private[r] class NaiveBayesWrapper private (
 
   lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
 
-  def transform(dataset: DataFrame): DataFrame = {
+  def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
   }
 }
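
These private[r] wrappers are thin forwarders: each transform delegates to PipelineModel.transform, which this commit also widened to Dataset[_] (see the Pipeline.scala change in the file list). A hypothetical sketch of the public pattern they sit on, loosely mirroring NaiveBayesWrapper's RFormula-plus-NaiveBayes pipeline:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.NaiveBayes
    import org.apache.spark.ml.feature.RFormula
    import sqlContext.implicits._

    val train = Seq((0.0, 1.0, 0.0), (1.0, 0.0, 1.0)).toDF("label", "f1", "f2")
    val pipeline = new Pipeline().setStages(Array(
      new RFormula().setFormula("label ~ f1 + f2"),
      new NaiveBayes()))
    val model = pipeline.fit(train)   // fit(Dataset[_])
    model.transform(train)            // transform(Dataset[_]) -> DataFrame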

