You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/27 08:51:35 UTC
spark git commit: [SPARK-7535] [.1] [MLLIB] minor changes to the
pipeline API
Repository: spark
Updated Branches:
refs/heads/master b463e6d61 -> a9f1c0c57
[SPARK-7535] [.1] [MLLIB] minor changes to the pipeline API
1. removed `Params.validateParams(extra)`
2. added `Evaluator.evaluate(dataset, paramPairs*)`
3. updated `RegressionEvaluator` doc
jkbradley
Author: Xiangrui Meng <me...@databricks.com>
Closes #6392 from mengxr/SPARK-7535.1 and squashes the following commits:
5ff5af8 [Xiangrui Meng] add unit test for CV.validateParams
f1f8369 [Xiangrui Meng] update CV.validateParams() to test estimatorParamMaps
607445d [Xiangrui Meng] merge master
8716f5f [Xiangrui Meng] specify default metric name in RegressionEvaluator
e4e5631 [Xiangrui Meng] update RegressionEvaluator doc
801e864 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7535.1
fcbd3e2 [Xiangrui Meng] Merge branch 'master' into SPARK-7535.1
2192316 [Xiangrui Meng] remove validateParams(extra); add evaluate(dataset, extra*)
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a9f1c0c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a9f1c0c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a9f1c0c5
Branch: refs/heads/master
Commit: a9f1c0c57b9be586dbada09dab91dcfce31141d9
Parents: b463e6d
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue May 26 23:51:32 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 26 23:51:32 2015 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/ml/Pipeline.scala | 9 ++--
.../ml/evaluation/RegressionEvaluator.scala | 4 +-
.../org/apache/spark/ml/param/params.scala | 13 -----
.../apache/spark/ml/tuning/CrossValidator.scala | 23 ++++++---
.../org/apache/spark/ml/param/ParamsSuite.scala | 2 +-
.../spark/ml/tuning/CrossValidatorSuite.scala | 52 +++++++++++++++++++-
6 files changed, 71 insertions(+), 32 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 9da3ff6..11a4722 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -97,12 +97,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
/** @group getParam */
def getStages: Array[PipelineStage] = $(stages).clone()
- override def validateParams(paramMap: ParamMap): Unit = {
- val map = extractParamMap(paramMap)
- getStages.foreach {
- case pStage: Params => pStage.validateParams(map)
- case _ =>
- }
+ override def validateParams(): Unit = {
+ super.validateParams()
+ $(stages).foreach(_.validateParams())
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 1771177..abb1b35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -36,8 +36,8 @@ final class RegressionEvaluator(override val uid: String)
def this() = this(Identifiable.randomUID("regEval"))
/**
- * param for metric name in evaluation
- * @group param supports mse, rmse, r2, mae as valid metric names.
+ * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ * @group param
*/
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1afa59c..473488d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -334,19 +334,6 @@ trait Params extends Identifiable with Serializable {
}
/**
- * Validates parameter values stored internally plus the input parameter map.
- * Raises an exception if any parameter is invalid.
- *
- * This only needs to check for interactions between parameters.
- * Parameter value checks which do not depend on other parameters are handled by
- * [[Param.validate()]]. This method does not handle input/output column parameters;
- * those are checked during schema validation.
- */
- def validateParams(paramMap: ParamMap): Unit = {
- copy(paramMap).validateParams()
- }
-
- /**
* Validates parameter values stored internally.
* Raise an exception if any parameter value is invalid.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2e5a629..6434b64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def validateParams(paramMap: ParamMap): Unit = {
- getEstimatorParamMaps.foreach { eMap =>
- getEstimator.validateParams(eMap ++ paramMap)
- }
- }
-
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
@@ -147,6 +141,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
override def transformSchema(schema: StructType): StructType = {
$(estimator).transformSchema(schema)
}
+
+ override def validateParams(): Unit = {
+ super.validateParams()
+ val est = $(estimator)
+ for (paramMap <- $(estimatorParamMaps)) {
+ est.copy(paramMap).validateParams()
+ }
+ }
}
/**
@@ -159,8 +161,8 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def validateParams(paramMap: ParamMap): Unit = {
- bestModel.validateParams(paramMap)
+ override def validateParams(): Unit = {
+ bestModel.validateParams()
}
override def transform(dataset: DataFrame): DataFrame = {
@@ -171,4 +173,9 @@ class CrossValidatorModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
bestModel.transformSchema(schema)
}
+
+ override def copy(extra: ParamMap): CrossValidatorModel = {
+ val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ copyValues(copied, extra)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index d270ad7..04f2af4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite {
intercept[IllegalArgumentException] {
solver.validateParams()
}
- solver.validateParams(ParamMap(inputCol -> "input"))
+ solver.copy(ParamMap(inputCol -> "input")).validateParams()
solver.setInputCol("input")
assert(solver.isSet(inputCol))
assert(solver.isDefined(inputCol))
http://git-wip-us.apache.org/repos/asf/spark/blob/a9f1c0c5/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 05313d4..65972ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -19,11 +19,15 @@ package org.apache.spark.ml.tuning
import org.scalatest.FunSuite
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.StructType
class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
@@ -53,4 +57,48 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
}
+
+ test("validateParams should check estimatorParamMaps") {
+ import CrossValidatorSuite._
+
+ val est = new MyEstimator("est")
+ val eval = new MyEvaluator
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(est.inputCol, Array("input1", "input2"))
+ .build()
+
+ val cv = new CrossValidator()
+ .setEstimator(est)
+ .setEstimatorParamMaps(paramMaps)
+ .setEvaluator(eval)
+
+ cv.validateParams() // This should pass.
+
+ val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+ cv.setEstimatorParamMaps(invalidParamMaps)
+ intercept[IllegalArgumentException] {
+ cv.validateParams()
+ }
+ }
+}
+
+object CrossValidatorSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+ override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+ override def fit(dataset: DataFrame): MyModel = ???
+
+ override def transformSchema(schema: StructType): StructType = ???
+ }
+
+ class MyEvaluator extends Evaluator {
+
+ override def evaluate(dataset: DataFrame): Double = ???
+
+ override val uid: String = "eval"
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org