You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/27 08:51:35 UTC

spark git commit: [SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API

Repository: spark
Updated Branches:
  refs/heads/master b463e6d61 -> a9f1c0c57


[SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API

1. removed `Params.validateParams(extra)`
2. added `Evaluator.evaluate(dataset, paramPairs*)`
3. updated `RegressionEvaluator` doc

jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits:

5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams
f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps
607445d [Xiangrui Meng] merge master
8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator
e4e5631 [Xiangrui Meng] update RegressionEvaluator doc
801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1
fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1
2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a9f1c0c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a9f1c0c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a9f1c0c5

Branch: refs/heads/master
Commit: a9f1c0c57b9be586dbada09dab91dcfce31141d9
Parents: b463e6d
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue May 26 23:51:32 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 26 23:51:32 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    |  9 ++--
 .../ml/evaluation/RegressionEvaluator.scala     |  4 +-
 .../org/apache/spark/ml/param/params.scala      | 13 -----
 .../apache/spark/ml/tuning/CrossValidator.scala | 23 ++++++---
 .../org/apache/spark/ml/param/ParamsSuite.scala |  2 +-
 .../spark/ml/tuning/CrossValidatorSuite.scala   | 52 +++++++++++++++++++-
 6 files changed, 71 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 9da3ff6..11a4722 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
   /** @group getParam */
   def getStages: Array[PipelineStage] = $(stages).clone()
 
-  override def validateParams(paramMap: ParamMap): Unit = {
-    val map = extractParamMap(paramMap)
-    getStages.foreach {
-      case pStage: Params => pStage.validateParams(map)
-      case _ =>
-    }
+  override def validateParams(): Unit = {
+    super.validateParams()
+    $(stages).foreach(_.validateParams())
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 1771177..abb1b35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String)
   def this() = this(Identifiable.randomUID("regEval"))
 
   /**
-   * param for metric name in evaluation
-   * @group param supports mse, rmse, r2, mae as valid metric names.
+   * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+   * @group param
    */
   val metricName: Param[String] = {
     val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))

http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1afa59c..473488d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -334,19 +334,6 @@ trait Params extends Identifiable with Serializable {
   }
 
   /**
-   * Validates parameter values stored internally plus the input parameter map.
-   * Raises an exception if any parameter is invalid.
-   *
-   * This only needs to check for interactions between parameters.
-   * Parameter value checks which do not depend on other parameters are handled by
-   * [[Param.validate()]].  This method does not handle input/output column parameters;
-   * those are checked during schema validation.
-   */
-  def validateParams(paramMap: ParamMap): Unit = {
-    copy(paramMap).validateParams()
-  }
-
-  /**
    * Validates parameter values stored internally.
    * Raise an exception if any parameter value is invalid.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2e5a629..6434b64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
   /** @group setParam */
   def setNumFolds(value: Int): this.type = set(numFolds, value)
 
-  override def validateParams(paramMap: ParamMap): Unit = {
-    getEstimatorParamMaps.foreach { eMap =>
-      getEstimator.validateParams(eMap ++ paramMap)
-    }
-  }
-
   override def fit(dataset: DataFrame): CrossValidatorModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
@@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
   override def transformSchema(schema: StructType): StructType = {
     $(estimator).transformSchema(schema)
   }
+
+  override def validateParams(): Unit = {
+    super.validateParams()
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps)) {
+      est.copy(paramMap).validateParams()
+    }
+  }
 }
 
 /**
@@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] (
     val bestModel: Model[_])
   extends Model[CrossValidatorModel] with CrossValidatorParams {
 
-  override def validateParams(paramMap: ParamMap): Unit = {
-    bestModel.validateParams(paramMap)
+  override def validateParams(): Unit = {
+    bestModel.validateParams()
   }
 
   override def transform(dataset: DataFrame): DataFrame = {
@@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] (
   override def transformSchema(schema: StructType): StructType = {
     bestModel.transformSchema(schema)
   }
+
+  override def copy(extra: ParamMap): CrossValidatorModel = {
+    val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+    copyValues(copied, extra)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index d270ad7..04f2af4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite {
     intercept[IllegalArgumentException] {
       solver.validateParams()
     }
-    solver.validateParams(ParamMap(inputCol -> "input"))
+    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))

http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 05313d4..65972ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -19,11 +19,15 @@ package org.apache.spark.ml.tuning
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -53,4 +57,48 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
   }
+
+  test("validateParams should check estimatorParamMaps") {
+    import CrossValidatorSuite._
+
+    val est = new MyEstimator("est")
+    val eval = new MyEvaluator
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(est.inputCol, Array("input1", "input2"))
+      .build()
+
+    val cv = new CrossValidator()
+      .setEstimator(est)
+      .setEstimatorParamMaps(paramMaps)
+      .setEvaluator(eval)
+
+    cv.validateParams() // This should pass.
+
+    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+    cv.setEstimatorParamMaps(invalidParamMaps)
+    intercept[IllegalArgumentException] {
+      cv.validateParams()
+    }
+  }
+}
+
+object CrossValidatorSuite {
+
+  abstract class MyModel extends Model[MyModel]
+
+  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+    override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+    override def fit(dataset: DataFrame): MyModel = ???
+
+    override def transformSchema(schema: StructType): StructType = ???
+  }
+
+  class MyEvaluator extends Evaluator {
+
+    override def evaluate(dataset: DataFrame): Double = ???
+
+    override val uid: String = "eval"
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org