Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/26 13:29:02 UTC

spark git commit: [SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML

Repository: spark
Updated Branches:
  refs/heads/master a88329d45 -> c4a7eef0c


[SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML

## What changes were proposed in this pull request?
Remove spark.ml methods that were deprecated in 2.0.0 and slated for removal in 2.1.0: `Params.validateParams`, the Random Forest models' `numTrees`, `ChiSqSelectorModel.setLabelCol`, and `LinearRegressionSummary.model`. Also clarify that the deprecated `context` read/write method will be removed in 2.2.0, and add the corresponding `session` methods to PySpark's ML reader/writer utilities.

## How was this patch tested?
Existing tests, plus new cases covering `stepSize` range validation (GBTClassifierSuite) and threshold consistency at fit time (LogisticRegressionSuite). MiMa exclusions were added for the removed APIs.

Author: Yanbo Liang <yb...@gmail.com>

Closes #15913 from yanboliang/spark-18481.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4a7eef0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4a7eef0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4a7eef0

Branch: refs/heads/master
Commit: c4a7eef0ce2d305c5c90a0a9a73b5a32eccfba95
Parents: a88329d
Author: Yanbo Liang <yb...@gmail.com>
Authored: Sat Nov 26 05:28:41 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Sat Nov 26 05:28:41 2016 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    |  4 +
 .../spark/ml/classification/GBTClassifier.scala |  6 ++
 .../ml/classification/LogisticRegression.scala  |  8 +-
 .../classification/RandomForestClassifier.scala | 11 +--
 .../apache/spark/ml/feature/ChiSqSelector.scala |  7 --
 .../org/apache/spark/ml/param/params.scala      | 15 ----
 .../spark/ml/regression/GBTRegressor.scala      |  6 ++
 .../spark/ml/regression/LinearRegression.scala  |  3 -
 .../ml/regression/RandomForestRegressor.scala   | 10 +--
 .../org/apache/spark/ml/tree/treeModels.scala   |  5 --
 .../org/apache/spark/ml/tree/treeParams.scala   | 90 +++++++++-----------
 .../org/apache/spark/ml/util/ReadWrite.scala    |  2 +-
 .../ml/classification/GBTClassifierSuite.scala  |  8 ++
 .../LogisticRegressionSuite.scala               |  6 ++
 project/MimaExcludes.scala                      | 30 +++++++
 python/pyspark/ml/util.py                       | 40 ++++++++-
 16 files changed, 144 insertions(+), 107 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index f406f8c..38176b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
    *
    * Check transform validity and derive the output schema from the input schema.
    *
+   * We check validity for interactions between parameters during `transformSchema` and
+   * raise an exception if any parameter value is invalid. Parameter value checks which
+   * do not depend on other parameters are handled by `Param.validate()`.
+   *
    * Typical implementation should first conduct verification on schema change and parameter
    * validity, including complex parameter interaction checks.
    */

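A minimal sketch of the contract documented above, using a hypothetical ScaleStage (not part of this commit): checks that involve several parameters at once go in transformSchema, while single-parameter constraints stay on each Param's own validator.

    import org.apache.spark.ml.PipelineStage
    import org.apache.spark.ml.param.{DoubleParam, ParamMap}
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.types.StructType

    class ScaleStage(override val uid: String) extends PipelineStage {
      def this() = this(Identifiable.randomUID("scaleStage"))

      // Single-parameter constraints would go in each Param's validator.
      val min = new DoubleParam(this, "min", "lower bound of the output range")
      val max = new DoubleParam(this, "max", "upper bound of the output range")
      setDefault(min -> 0.0, max -> 1.0)

      override def transformSchema(schema: StructType): StructType = {
        // Interaction between parameters is checked here, per the doc above.
        require($(min) < $(max), s"min (${$(min)}) must be less than max (${$(max)})")
        schema  // this sketch leaves the schema unchanged
      }

      override def copy(extra: ParamMap): ScaleStage = defaultCopy(extra)
    }
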
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 52f93f5..ca52231 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -203,6 +203,12 @@ class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
+  /**
+   * Number of trees in ensemble
+   */
+  @Since("2.0.0")
+  val getNumTrees: Int = trees.length
+
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 

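A usage sketch for the member added above (a `train` DataFrame is assumed): the ensemble size of a fitted model is read via getNumTrees, and for GBTs it equals the number of boosting iterations.

    import org.apache.spark.ml.classification.GBTClassifier

    val model = new GBTClassifier().setMaxIter(10).fit(train)  // `train` assumed
    println(model.getNumTrees)  // 10: one tree per boosting iteration
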
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fe29926..41b84f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils
 
@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
     }
   }
 
-  override def validateParams(): Unit = {
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     checkThresholdConsistency()
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }
 

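A sketch of what the override above changes (a `dataset` DataFrame is assumed): the setters keep threshold and thresholds consistent, but a ParamMap can still supply a conflicting pair, and that now fails during schema validation rather than later.

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()
    lr.fit(dataset,                        // `dataset` assumed
      lr.thresholds -> Array(0.5, 0.5),    // implies threshold 0.5
      lr.threshold -> 0.2)                 // conflicts: IllegalArgumentException
                                           // from checkThresholdConsistency()
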
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 907c73e..d151213 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
     @Since("1.6.0") override val numFeatures: Int,
     @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
-  with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+  with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
   with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
     }
   }
 
-  /**
-   * Number of trees in ensemble
-   *
-   * @deprecated  Use [[getNumTrees]] instead.  This method will be removed in 2.1.0
-   */
-  // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
-  @deprecated("Use getNumTrees instead.  This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)

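A migration sketch for the removal above (the same applies to RandomForestRegressionModel): the deprecated `numTrees` val is gone, so read the ensemble size through the `getNumTrees` param getter. A `train` DataFrame is assumed.

    import org.apache.spark.ml.classification.RandomForestClassifier

    val model = new RandomForestClassifier().setNumTrees(30).fit(train)
    // before (removed in this commit):  model.numTrees
    // after:
    val n: Int = model.getNumTrees  // 30
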
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 653fa41..7cd0f15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  /**
-   * @group setParam
-   */
-  @Since("1.6.0")
-  @deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
-  def setLabelCol(value: String): this.type = set(labelCol, value)
-
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)

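A migration sketch for the removed setter: the label column is only consumed while fitting, so configure it on the ChiSqSelector estimator; the fitted model no longer exposes setLabelCol. A `train` DataFrame is assumed.

    import org.apache.spark.ml.feature.ChiSqSelector

    val selector = new ChiSqSelector()
      .setNumTopFeatures(50)
      .setFeaturesCol("features")
      .setLabelCol("label")          // used by fit(); no longer settable on the model
    val model = selector.fit(train)
    val selected = model.setOutputCol("selected").transform(train)
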
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 96206e0..5bd8ebe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -547,21 +547,6 @@ trait Params extends Identifiable with Serializable {
   }
 
   /**
-   * Validates parameter values stored internally.
-   * Raise an exception if any parameter value is invalid.
-   *
-   * This only needs to check for interactions between parameters.
-   * Parameter value checks which do not depend on other parameters are handled by
-   * `Param.validate()`. This method does not handle input/output column parameters;
-   * those are checked during schema validation.
-   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
-   */
-  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
-  def validateParams(): Unit = {
-    // Do nothing by default.  Override to handle Param interactions.
-  }
-
-  /**
    * Explains a param.
    * @param param input param, must belong to this instance.
    * @return a string that contains the input param name, doc, and optionally its default value and

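The practical consequence for callers, sketched (`stage` and `df` are assumed): the interaction checks that lived in validateParams() now run inside transformSchema, which fit() and transform() invoke, so explicit validation reduces to:

    // Throws IllegalArgumentException if cross-parameter constraints are violated.
    stage.transformSchema(df.schema)
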
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index ed2d055..6d8159a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
+  /**
+   * Number of trees in ensemble
+   */
+  @Since("2.0.0")
+  val getNumTrees: Int = trees.length
+
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index eb4e38c..19ddf36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] (
     private val privateModel: LinearRegressionModel,
     private val diagInvAtWA: Array[Double]) extends Serializable {
 
-  @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
-  val model: LinearRegressionModel = privateModel
-
   @transient private val metrics = new RegressionMetrics(
     predictions
       .select(col(predictionCol), col(labelCol).cast(DoubleType))

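A migration sketch for the removed field: keep a reference to the fitted model rather than reading it back off its summary. A `train` DataFrame is assumed.

    import org.apache.spark.ml.regression.LinearRegression

    val lrModel = new LinearRegression().fit(train)
    val summary = lrModel.summary
    // before (removed in this commit):  summary.model.coefficients
    // after:
    val coefs = lrModel.coefficients
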
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index d60f05e..90d89c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -145,7 +145,7 @@ class RandomForestRegressionModel private[ml] (
     private val _trees: Array[DecisionTreeRegressionModel],
     override val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
-  with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+  with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
   with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -182,14 +182,6 @@ class RandomForestRegressionModel private[ml] (
     _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
   }
 
-  /**
-   * Number of trees in ensemble
-   * @deprecated  Use [[getNumTrees]] instead.  This method will be removed in 2.1.0
-   */
-  // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
-  @deprecated("Use getNumTrees instead.  This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestRegressionModel = {
     copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)

http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index d3cbc36..0d6e903 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
   /** Trees in this ensemble. Warning: These have null parent Estimators. */
   def trees: Array[M]
 
-  /**
-   * Number of trees in ensemble
-   */
-  val getNumTrees: Int = trees.length
-
   /** Weights for each tree, zippable with [[trees]] */
   def treeWeights: Array[Double]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 40510ad8..83ab4b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -319,8 +319,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
   }
 }
 
-/** Used for [[RandomForestParams]] */
-private[ml] trait HasFeatureSubsetStrategy extends Params {
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
+   * (default = 20)
+   *
+   * Note: The reason we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+   * is that the param `maxIter` controls how many trees a GBT has. The semantics of the
+   * two algorithms differ slightly.
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
+
+  setDefault(numTrees -> 20)
+
+  /** @group setParam */
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+  /** @group getParam */
+  final def getNumTrees: Int = $(numTrees)
 
   /**
    * The number of features to consider for splits at each tree node.
@@ -366,38 +390,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
   final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
 }
 
-/**
- * Used for [[RandomForestParams]].
- * This is separated out from [[RandomForestParams]] because of an issue with the
- * `numTrees` method conflicting with this Param in the Estimator.
- */
-private[ml] trait HasNumTrees extends Params {
-
-  /**
-   * Number of trees to train (>= 1).
-   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
-   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
-   * (default = 20)
-   * @group param
-   */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidators.gtEq(1))
-
-  setDefault(numTrees -> 20)
-
-  /** @group setParam */
-  def setNumTrees(value: Int): this.type = set(numTrees, value)
-
-  /** @group getParam */
-  final def getNumTrees: Int = $(numTrees)
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with HasNumTrees
-
 private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =
@@ -407,21 +399,15 @@ private[spark] object RandomForestParams {
 private[ml] trait RandomForestClassifierParams
   extends RandomForestParams with TreeClassifierParams
 
-private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeClassifierParams
-
 private[ml] trait RandomForestRegressorParams
   extends RandomForestParams with TreeRegressorParams
 
-private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeRegressorParams
-
 /**
  * Parameters for Gradient-Boosted Tree algorithms.
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
 
   /* TODO: Add this doc when we add this param.  SPARK-7132
    * Threshold for stopping early when runWithValidation is used.
@@ -434,24 +420,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
   // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
   // validationTol -> 1e-5
 
-  setDefault(maxIter -> 20, stepSize -> 0.1)
-
   /** @group setParam */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /**
-   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
-   * estimator.
+   * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
+   * the contribution of each estimator.
    * (default = 0.1)
-   * @group setParam
+   * @group param
    */
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
+    "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+  /** @group getParam */
+  final def getStepSize: Double = $(stepSize)
+
+  /** @group setParam */
   def setStepSize(value: Double): this.type = set(stepSize, value)
 
-  override def validateParams(): Unit = {
-    require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
-      getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
-      s"but it given invalid value $getStepSize.")
-  }
+  setDefault(maxIter -> 20, stepSize -> 0.1)
 
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
   private[ml] def getOldBoostingStrategy(

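One upshot of folding numTrees into RandomForestParams as a regular Param, sketched: it participates in ParamMaps and tuning like any other param. A `train` DataFrame is assumed.

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val rf = new RandomForestClassifier()
    val grid = new ParamGridBuilder()
      .addGrid(rf.numTrees, Array(20, 50, 100))  // numTrees is now an IntParam
      .build()
    val cv = new CrossValidator()
      .setEstimator(rf)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
    val bestModel = cv.fit(train).bestModel
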
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 5b7e5ec..bbb9886 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite {
    * Sets the Spark SQLContext to use for saving/loading.
    */
   @Since("1.6.0")
-  @deprecated("Use session instead", "2.0.0")
+  @deprecated("Use session instead. This method will be removed in 2.2.0.", "2.0.0")
   def context(sqlContext: SQLContext): this.type = {
     optionSparkSession = Option(sqlContext.sparkSession)
     this

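A migration sketch for the clarified deprecation above (`spark: SparkSession` and `path` are assumed); the same context-to-session move applies to the PySpark readers and writers further below.

    import org.apache.spark.ml.classification.LogisticRegressionModel

    // before (deprecated; removal targeted for 2.2.0):
    //   LogisticRegressionModel.read.context(sqlContext).load(path)
    // after:
    val model = LogisticRegressionModel.read.session(spark).load(path)
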
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3492709..7c36745 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     ParamsSuite.checkParams(model)
   }
 
+  test("GBT parameter stepSize should be in interval (0, 1]") {
+    withClue("GBT parameter stepSize should be in interval (0, 1]") {
+      intercept[IllegalArgumentException] {
+        new GBTClassifier().setStepSize(10)
+      }
+    }
+  }
+
   test("Binary classification with continuous features: Log Loss") {
     val categoricalFeatures = Map.empty[Int, Int]
     testCombinations.foreach {

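What the new test above exercises, sketched: stepSize now carries a range validator on the Param itself, so an out-of-range value fails as soon as it is set, rather than at the removed validateParams() call.

    import org.apache.spark.ml.classification.GBTClassifier

    val gbt = new GBTClassifier()
    gbt.setStepSize(0.05)   // fine: within (0, 1]
    // gbt.setStepSize(10)  // throws IllegalArgumentException immediately
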
http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e360542..9c4c59a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -194,6 +194,12 @@ class LogisticRegressionSuite
     // thresholds and threshold must be consistent: values
     withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
       intercept[IllegalArgumentException] {
+        lr2.fit(smallBinaryDataset,
+          lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+      }
+    }
+    withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+      intercept[IllegalArgumentException] {
         val lr2model = lr2.fit(smallBinaryDataset,
           lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
         lr2model.getThreshold

http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 12f7ed2..8401401 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -867,6 +867,36 @@ object MimaExcludes {
       // [SPARK-12221] Add CPU time to metrics
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
+    ) ++ Seq(
+      // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7d39c30..bec4b28 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -78,7 +78,14 @@ class MLWriter(object):
         raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
 
     def context(self, sqlContext):
-        """Sets the SQL context to use for saving."""
+        """
+        Sets the SQL context to use for saving.
+        .. note:: Deprecated in 2.1 and will be removed in 2.2; use session instead.
+        """
+        raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+    def session(self, sparkSession):
+        """Sets the Spark Session to use for saving."""
         raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
 
 
@@ -105,10 +112,19 @@ class JavaMLWriter(MLWriter):
         return self
 
     def context(self, sqlContext):
-        """Sets the SQL context to use for saving."""
+        """
+        Sets the SQL context to use for saving.
+        .. note:: Deprecated in 2.1 and will be removed in 2.2; use session instead.
+        """
+        warnings.warn("Deprecated in 2.1 and will be removed in 2.2; use session instead.")
         self._jwrite.context(sqlContext._ssql_ctx)
         return self
 
+    def session(self, sparkSession):
+        """Sets the Spark Session to use for saving."""
+        self._jwrite.session(sparkSession._jsparkSession)
+        return self
+
 
 @inherit_doc
 class MLWritable(object):
@@ -155,7 +171,14 @@ class MLReader(object):
         raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
 
     def context(self, sqlContext):
-        """Sets the SQL context to use for loading."""
+        """
+        Sets the SQL context to use for loading.
+        .. note:: Deprecated in 2.1 and will be removed in 2.2; use session instead.
+        """
+        raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+    def session(self, sparkSession):
+        """Sets the Spark Session to use for loading."""
         raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
 
 
@@ -180,10 +203,19 @@ class JavaMLReader(MLReader):
         return self._clazz._from_java(java_obj)
 
     def context(self, sqlContext):
-        """Sets the SQL context to use for loading."""
+        """
+        Sets the SQL context to use for loading.
+        .. note:: Deprecated in 2.1 and will be removed in 2.2; use session instead.
+        """
+        warnings.warn("Deprecated in 2.1 and will be removed in 2.2; use session instead.")
         self._jread.context(sqlContext._ssql_ctx)
         return self
 
+    def session(self, sparkSession):
+        """Sets the Spark Session to use for loading."""
+        self._jread.session(sparkSession._jsparkSession)
+        return self
+
     @classmethod
     def _java_loader_class(cls, clazz):
         """

