Posted to commits@spark.apache.org by me...@apache.org on 2015/05/04 20:29:06 UTC
[3/3] spark git commit: [SPARK-5956] [MLLIB] Pipeline components should be copyable.
[SPARK-5956] [MLLIB] Pipeline components should be copyable.
This PR added `copy(extra: ParamMap): Params` to `Params`, which makes a copy of the current instance with a randomly generated uid and some extra param values. With this change, concrete components only need to implement `fit` and `transform` for the embedded param values; the overloads that take extra params come from the default implementation of `fit(dataset, extra)`:
~~~scala
def fit(dataset: DataFrame, extra: ParamMap): Model = {
copy(extra).fit(dataset)
}
~~~
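To show the contract this buys us, here is a self-contained miniature of the copy-then-fit pattern. The names (`MeanShift`, and the simplified `ParamMap` and `Estimator`) are illustrative stand-ins, not the Spark API:
~~~scala
// Miniature of the pattern: fit(dataset, extra) delegates to a copy,
// so fit(dataset) only ever reads the instance's embedded params.
final case class ParamMap(values: Map[String, Any] = Map.empty)

abstract class Estimator[M] {
  /** Subclasses implement fitting against embedded params only. */
  def fit(dataset: Seq[Double]): M

  /** Returns a copy of this estimator with the extra params merged in. */
  def copy(extra: ParamMap): Estimator[M]

  /** Default implementation: copy first, then fit the copy. */
  def fit(dataset: Seq[Double], extra: ParamMap): M =
    copy(extra).fit(dataset)
}

class MeanShift(val offset: Double) extends Estimator[Double] {
  def fit(dataset: Seq[Double]): Double =
    dataset.sum / dataset.size + offset
  def copy(extra: ParamMap): MeanShift =
    new MeanShift(extra.values.getOrElse("offset", offset).asInstanceOf[Double])
}

object Demo extends App {
  val est = new MeanShift(offset = 0.0)
  // The extra param overrides the embedded offset for this one call only.
  println(est.fit(Seq(1.0, 2.0, 3.0), ParamMap(Map("offset" -> 10.0)))) // 12.0
}
~~~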
Inside `fit` and `transform`, only the embedded param values are used, so I added `$` as an alias for `getOrDefault` to make the code easier to read. For example, in `LinearRegression.fit` we have:
~~~scala
val effectiveRegParam = $(regParam) / yStd
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
~~~
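The alias itself is tiny; here is a sketch consistent with the description above (simplified `Param`/`Params` stand-ins, not the full definitions in `params.scala`):
~~~scala
class Param[T](val name: String)

trait Params {
  def getOrDefault[T](param: Param[T]): T

  /** An alias for getOrDefault, so method bodies read as $(regParam). */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)
}
~~~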
Meta-algorithms like `Pipeline` implement their own `copy(extra)`, so the fitted pipeline model stores all copied stages (whether each stage is a transformer or a model).
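For reference, the `Pipeline` implementation of `copy(extra)` (quoted from the diff below) copies every stage before rebuilding the pipeline:
~~~scala
override def copy(extra: ParamMap): Pipeline = {
  val map = extractParamMap(extra)
  val newStages = map(stages).map(_.copy(extra))
  new Pipeline().setStages(newStages)
}
~~~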
Other changes:
* `Params$.inheritValues` is moved to `Params!.copyValues` and returns the target instance (a usage sketch follows this list).
* `fittingParamMap` was removed because the `parent` carries this information.
* `validate` was renamed to `validateParams` to be more precise.
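Trained models follow the same pattern in their `copy(extra)` overrides, for example `LogisticRegressionModel` as it appears in this patch:
~~~scala
override def copy(extra: ParamMap): LogisticRegressionModel = {
  copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
}
~~~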
TODOs:
* [x] add tests for newly added methods
* [ ] update documentation
CC: jkbradley dbtsai
Author: Xiangrui Meng <me...@databricks.com>
Closes #5820 from mengxr/SPARK-5956 and squashes the following commits:
7bef88d [Xiangrui Meng] address comments
05229c3 [Xiangrui Meng] assert -> assertEquals
b2927b1 [Xiangrui Meng] organize imports
f14456b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
93e7924 [Xiangrui Meng] add tests for hasParam & copy
463ecae [Xiangrui Meng] merge master
2b954c3 [Xiangrui Meng] update Binarizer
465dd12 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
282a1a8 [Xiangrui Meng] fix test
819dd2d [Xiangrui Meng] merge master
b642872 [Xiangrui Meng] example code runs
5a67779 [Xiangrui Meng] examples compile
c76b4d1 [Xiangrui Meng] fix all unit tests
0f4fd64 [Xiangrui Meng] fix some tests
9286a22 [Xiangrui Meng] copyValues to trained models
53e0973 [Xiangrui Meng] move inheritValues to Params and rename it to copyValues
9ee004e [Xiangrui Meng] merge copy and copyWith; rename validate to validateParams
d882afc [Xiangrui Meng] test compile
f082a31 [Xiangrui Meng] make Params copyable and simply handling of extra params in all spark.ml components
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0833c59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0833c59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0833c59
Branch: refs/heads/master
Commit: e0833c5958bbd73ff27cfe6865648d7b6e5a99bc
Parents: 5a1a107
Author: Xiangrui Meng <me...@databricks.com>
Authored: Mon May 4 11:28:59 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon May 4 11:28:59 2015 -0700
----------------------------------------------------------------------
.../examples/ml/JavaDeveloperApiExample.java | 24 ++---
.../examples/ml/JavaSimpleParamsExample.java | 4 +-
.../spark/examples/ml/DecisionTreeExample.scala | 6 +-
.../spark/examples/ml/DeveloperApiExample.scala | 22 ++--
.../apache/spark/examples/ml/GBTExample.scala | 4 +-
.../spark/examples/ml/RandomForestExample.scala | 6 +-
.../spark/examples/ml/SimpleParamsExample.scala | 4 +-
.../scala/org/apache/spark/ml/Estimator.scala | 26 +++--
.../scala/org/apache/spark/ml/Evaluator.scala | 20 +++-
.../main/scala/org/apache/spark/ml/Model.scala | 9 +-
.../scala/org/apache/spark/ml/Pipeline.scala | 106 +++++++++----------
.../scala/org/apache/spark/ml/Transformer.scala | 46 +++++---
.../spark/ml/classification/Classifier.scala | 49 +++------
.../classification/DecisionTreeClassifier.scala | 29 ++---
.../spark/ml/classification/GBTClassifier.scala | 33 +++---
.../ml/classification/LogisticRegression.scala | 58 +++++-----
.../ProbabilisticClassifier.scala | 33 ++----
.../classification/RandomForestClassifier.scala | 31 +++---
.../BinaryClassificationEvaluator.scala | 17 ++-
.../org/apache/spark/ml/feature/Binarizer.scala | 20 ++--
.../org/apache/spark/ml/feature/HashingTF.scala | 10 +-
.../scala/org/apache/spark/ml/feature/IDF.scala | 38 +++----
.../apache/spark/ml/feature/Normalizer.scala | 10 +-
.../spark/ml/feature/PolynomialExpansion.scala | 9 +-
.../spark/ml/feature/StandardScaler.scala | 49 ++++-----
.../apache/spark/ml/feature/StringIndexer.scala | 34 +++---
.../org/apache/spark/ml/feature/Tokenizer.scala | 18 ++--
.../spark/ml/feature/VectorAssembler.scala | 15 ++-
.../apache/spark/ml/feature/VectorIndexer.scala | 74 ++++++-------
.../org/apache/spark/ml/feature/Word2Vec.scala | 62 +++++------
.../spark/ml/impl/estimator/Predictor.scala | 72 ++++---------
.../apache/spark/ml/impl/tree/treeParams.scala | 35 +++---
.../org/apache/spark/ml/param/params.scala | 75 +++++++------
.../ml/param/shared/SharedParamsCodeGen.scala | 5 +-
.../spark/ml/param/shared/sharedParams.scala | 35 +++---
.../apache/spark/ml/recommendation/ALS.scala | 73 ++++++-------
.../ml/regression/DecisionTreeRegressor.scala | 23 ++--
.../spark/ml/regression/GBTRegressor.scala | 30 ++----
.../spark/ml/regression/LinearRegression.scala | 41 ++++---
.../ml/regression/RandomForestRegressor.scala | 27 ++---
.../apache/spark/ml/regression/Regressor.scala | 2 +-
.../apache/spark/ml/tuning/CrossValidator.scala | 51 ++++-----
.../JavaLogisticRegressionSuite.java | 14 +--
.../regression/JavaLinearRegressionSuite.java | 21 ++--
.../ml/tuning/JavaCrossValidatorSuite.java | 6 +-
.../org/apache/spark/ml/PipelineSuite.scala | 26 ++---
.../DecisionTreeClassifierSuite.scala | 4 +-
.../ml/classification/GBTClassifierSuite.scala | 4 +-
.../LogisticRegressionSuite.scala | 18 ++--
.../RandomForestClassifierSuite.scala | 4 +-
.../org/apache/spark/ml/param/ParamsSuite.scala | 13 ++-
.../org/apache/spark/ml/param/TestParams.scala | 14 ++-
.../regression/DecisionTreeRegressorSuite.scala | 4 +-
.../spark/ml/regression/GBTRegressorSuite.scala | 3 +-
.../regression/RandomForestRegressorSuite.scala | 4 +-
.../spark/ml/tuning/CrossValidatorSuite.scala | 6 +-
56 files changed, 671 insertions(+), 805 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 46377a9..eac4f89 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -129,16 +128,16 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
- JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
// Do learning to estimate the weight vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ return new MyJavaLogisticRegressionModel(this, weights);
}
}
@@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel
private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
- private ParamMap fittingParamMap_;
- public ParamMap fittingParamMap() { return fittingParamMap_; }
-
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(
- MyJavaLogisticRegression parent_,
- ParamMap fittingParamMap_,
- Vector weights_) {
+ public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
- this.fittingParamMap_ = fittingParamMap_;
this.weights_ = weights_;
}
@@ -210,10 +202,8 @@ class MyJavaLogisticRegressionModel
* In Java, we have to make this method public since Java does not understand Scala's protected
* modifier.
*/
- public MyJavaLogisticRegressionModel copy() {
- MyJavaLogisticRegressionModel m =
- new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
- Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
- return m;
+ @Override
+ public MyJavaLogisticRegressionModel copy(ParamMap extra) {
+ return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 4e02acc..29158d5 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -71,7 +71,7 @@ public class JavaSimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+ System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -87,7 +87,7 @@ public class JavaSimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
- System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+ System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List<LabeledPoint> localTest = Lists.newArrayList(
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 9002e99..8340d91 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
- val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
- dt.asInstanceOf[DecisionTreeClassifier])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
- val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
- dt.asInstanceOf[DecisionTreeRegressor])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2245fa4..2a2d067 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -18,13 +18,12 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
-import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
+import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
@@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = getOrDefault(maxIter)
+ def getMaxIter: Int = $(maxIter)
}
/**
@@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): MyLogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, paramMap, weights)
+ new MyLogisticRegressionModel(this, weights)
}
}
@@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
- override protected def copy(): MyLogisticRegressionModel = {
- val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): MyLogisticRegressionModel = {
+ copyValues(new MyLogisticRegressionModel(parent, weights), extra)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 5fccb14..c5899b6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 9b90932..7f88d26 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
- dt.asInstanceOf[RandomForestClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
- dt.asInstanceOf[RandomForestRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index bf80514..e8a991f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
- println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+ println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
// Prepare test data.
val test = sc.parallelize(Seq(
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index d6b3503..7f3f326 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -34,13 +34,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
- * @param paramPairs Optional list of param pairs.
- * These values override any specified in this Estimator's embedded ParamMap.
+ * @param firstParamPair the first param pair, overrides embedded params
+ * @param otherParamPairs other param pairs. These values override any specified in this
+ * Estimator's embedded ParamMap.
* @return fitted model
*/
@varargs
- def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
- val map = ParamMap(paramPairs: _*)
+ def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+ val map = new ParamMap()
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
fit(dataset, map)
}
@@ -52,12 +55,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
- def fit(dataset: DataFrame, paramMap: ParamMap): M
+ def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ copy(paramMap).fit(dataset)
+ }
+
+ /**
+ * Fits a model to the input data.
+ */
+ def fit(dataset: DataFrame): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
* The default implementation uses a for loop on each parameter map.
- * Subclasses could overwrite this to optimize multi-model training.
+ * Subclasses could override this to optimize multi-model training.
*
* @param dataset input dataset
* @param paramMaps An array of parameter maps.
@@ -67,4 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
+
+ override def copy(extra: ParamMap): Estimator[M] = {
+ super.copy(extra).asInstanceOf[Estimator[M]]
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
index 8b4b5fd..5f2f8c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -18,8 +18,7 @@
package org.apache.spark.ml
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.DataFrame
/**
@@ -27,7 +26,7 @@ import org.apache.spark.sql.DataFrame
* Abstract class for evaluators that compute metrics from predictions.
*/
@AlphaComponent
-abstract class Evaluator extends Identifiable {
+abstract class Evaluator extends Params {
/**
* Evaluates the output.
@@ -36,5 +35,18 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
+ def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
+ this.copy(paramMap).evaluate(dataset)
+ }
+
+ /**
+ * Evaluates the output.
+ * @param dataset a dataset that contains labels/observations and predictions.
+ * @return metric
+ */
+ def evaluate(dataset: DataFrame): Double
+
+ override def copy(extra: ParamMap): Evaluator = {
+ super.copy(extra).asInstanceOf[Evaluator]
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/Model.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index a491bc7..9974efe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -34,9 +34,8 @@ abstract class Model[M <: Model[M]] extends Transformer {
*/
val parent: Estimator[M]
- /**
- * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
- * Note: For ensembles' component Models, this value can be null.
- */
- val fittingParamMap: ParamMap
+ override def copy(extra: ParamMap): M = {
+ // The default implementation of Params.copy doesn't work for models.
+ throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 6bfeecd..33d430f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{Params, Param, ParamMap}
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -30,40 +30,41 @@ import org.apache.spark.sql.types.StructType
* A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
*/
@AlphaComponent
-abstract class PipelineStage extends Serializable with Logging {
+abstract class PipelineStage extends Params with Logging {
/**
* :: DeveloperApi ::
*
- * Derives the output schema from the input schema and parameters.
- * The schema describes the columns and types of the data.
- *
- * @param schema Input schema to this stage
- * @param paramMap Parameters passed to this stage
- * @return Output schema from this stage
+ * Derives the output schema from the input schema.
*/
@DeveloperApi
- def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+ def transformSchema(schema: StructType): StructType
/**
+ * :: DeveloperApi ::
+ *
* Derives the output schema from the input schema and parameters, optionally with logging.
*
* This should be optimistic. If it is unclear whether the schema will be valid, then it should
* be assumed valid until proven otherwise.
*/
+ @DeveloperApi
protected def transformSchema(
schema: StructType,
- paramMap: ParamMap,
logging: Boolean): StructType = {
if (logging) {
logDebug(s"Input schema: ${schema.json}")
}
- val outputSchema = transformSchema(schema, paramMap)
+ val outputSchema = transformSchema(schema)
if (logging) {
logDebug(s"Expected output schema: ${outputSchema.json}")
}
outputSchema
}
+
+ override def copy(extra: ParamMap): PipelineStage = {
+ super.copy(extra).asInstanceOf[PipelineStage]
+ }
}
/**
@@ -81,15 +82,22 @@ abstract class PipelineStage extends Serializable with Logging {
@AlphaComponent
class Pipeline extends Estimator[PipelineModel] {
- /** param for pipeline stages */
+ /**
+ * param for pipeline stages
+ * @group param
+ */
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+
+ /** @group setParam */
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
- def getStages: Array[PipelineStage] = getOrDefault(stages)
- override def validate(paramMap: ParamMap): Unit = {
+ /** @group getParam */
+ def getStages: Array[PipelineStage] = $(stages).clone()
+
+ override def validateParams(paramMap: ParamMap): Unit = {
val map = extractParamMap(paramMap)
getStages.foreach {
- case pStage: Params => pStage.validate(map)
+ case pStage: Params => pStage.validateParams(map)
case _ =>
}
}
@@ -104,13 +112,11 @@ class Pipeline extends Estimator[PipelineModel] {
* pipeline stages. If there are no stages, the output model acts as an identity transformer.
*
* @param dataset input dataset
- * @param paramMap parameter map
* @return fitted pipeline
*/
- override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val theStages = map(stages)
+ override def fit(dataset: DataFrame): PipelineModel = {
+ transformSchema(dataset.schema, logging = true)
+ val theStages = $(stages)
// Search for the last estimator.
var indexOfLastEstimator = -1
theStages.view.zipWithIndex.foreach { case (stage, index) =>
@@ -126,7 +132,7 @@ class Pipeline extends Estimator[PipelineModel] {
if (index <= indexOfLastEstimator) {
val transformer = stage match {
case estimator: Estimator[_] =>
- estimator.fit(curDataset, paramMap)
+ estimator.fit(curDataset)
case t: Transformer =>
t
case _ =>
@@ -134,7 +140,7 @@ class Pipeline extends Estimator[PipelineModel] {
s"Do not support stage $stage of type ${stage.getClass}")
}
if (index < indexOfLastEstimator) {
- curDataset = transformer.transform(curDataset, paramMap)
+ curDataset = transformer.transform(curDataset)
}
transformers += transformer
} else {
@@ -142,15 +148,20 @@ class Pipeline extends Estimator[PipelineModel] {
}
}
- new PipelineModel(this, map, transformers.toArray)
+ new PipelineModel(this, transformers.toArray)
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val theStages = map(stages)
+ override def copy(extra: ParamMap): Pipeline = {
+ val map = extractParamMap(extra)
+ val newStages = map(stages).map(_.copy(extra))
+ new Pipeline().setStages(newStages)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
- theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
+ theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}
}
@@ -161,43 +172,24 @@ class Pipeline extends Estimator[PipelineModel] {
@AlphaComponent
class PipelineModel private[ml] (
override val parent: Pipeline,
- override val fittingParamMap: ParamMap,
- private[ml] val stages: Array[Transformer])
+ val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
- override def validate(paramMap: ParamMap): Unit = {
- val map = fittingParamMap ++ extractParamMap(paramMap)
- stages.foreach(_.validate(map))
+ override def validateParams(): Unit = {
+ super.validateParams()
+ stages.foreach(_.validateParams())
}
- /**
- * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
- * estimator does not exist in the pipeline.
- */
- def getModel[M <: Model[M]](stage: Estimator[M]): M = {
- val matched = stages.filter {
- case m: Model[_] => m.parent.eq(stage)
- case _ => false
- }
- if (matched.isEmpty) {
- throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
- } else if (matched.length > 1) {
- throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
- } else {
- matched.head.asInstanceOf[M]
- }
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = fittingParamMap ++ extractParamMap(paramMap)
- transformSchema(dataset.schema, map, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
+ override def transformSchema(schema: StructType): StructType = {
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = fittingParamMap ++ extractParamMap(paramMap)
- stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
+ override def copy(extra: ParamMap): PipelineModel = {
+ new PipelineModel(parent, stages)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 0acda71..d96b54e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -37,13 +37,18 @@ abstract class Transformer extends PipelineStage with Params {
/**
* Transforms the dataset with optional parameters
* @param dataset input dataset
- * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @param firstParamPair the first param pair, overwrite embedded params
+ * @param otherParamPairs other param pairs, overwrite embedded params
* @return transformed dataset
*/
@varargs
- def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
+ def transform(
+ dataset: DataFrame,
+ firstParamPair: ParamPair[_],
+ otherParamPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
- paramPairs.foreach(map.put(_))
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
transform(dataset, map)
}
@@ -53,7 +58,18 @@ abstract class Transformer extends PipelineStage with Params {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
+ def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ this.copy(paramMap).transform(dataset)
+ }
+
+ /**
+ * Transforms the input dataset.
+ */
+ def transform(dataset: DataFrame): DataFrame
+
+ override def copy(extra: ParamMap): Transformer = {
+ super.copy(extra).asInstanceOf[Transformer]
+ }
}
/**
@@ -74,7 +90,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
* account of the embedded param map. So the param values should be determined solely by the input
* param map.
*/
- protected def createTransformFunc(paramMap: ParamMap): IN => OUT
+ protected def createTransformFunc: IN => OUT
/**
* Returns the data type of the output column.
@@ -86,22 +102,20 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
*/
protected def validateInputType(inputType: DataType): Unit = {}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
- if (schema.fieldNames.contains(map(outputCol))) {
- throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+ if (schema.fieldNames.contains($(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
}
val outputFields = schema.fields :+
- StructField(map(outputCol), outputDataType, nullable = false)
+ StructField($(outputCol), outputDataType, nullable = false)
StructType(outputFields)
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- dataset.withColumn(map(outputCol),
- callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ dataset.withColumn($(outputCol),
+ callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 29339c9..d3361e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -27,7 +26,6 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
-
/**
* :: DeveloperApi ::
* Params for classification.
@@ -40,12 +38,10 @@ private[spark] trait ClassifierParams extends PredictorParams
override protected def validateAndTransformSchema(
schema: StructType,
- paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = extractParamMap(paramMap)
- SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+ val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}
}
@@ -102,27 +98,16 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]].
*
* @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
-
- // Prepare model
- val tmpModel = if (paramMap.size != 0) {
- val tmpModel = this.copy()
- Params.inheritValues(paramMap, parent, tmpModel)
- tmpModel
- } else {
- this
- }
+ transformSchema(dataset.schema, logging = true)
val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
@@ -158,7 +143,6 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
*/
@DeveloperApi
protected def predictRaw(features: FeaturesType): Vector
-
}
private[ml] object ClassificationModel {
@@ -167,38 +151,35 @@ private[ml] object ClassificationModel {
* Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
* since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
* @param dataset Input dataset
- * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge
- * should already be done.
* @return (number of columns added, transformed dataset)
*/
def transformColumnsImpl[FeaturesType](
dataset: DataFrame,
- model: ClassificationModel[FeaturesType, _],
- map: ParamMap): (Int, DataFrame) = {
+ model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = {
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var tmpData = dataset
var numColsOutput = 0
- if (map(model.rawPredictionCol) != "") {
+ if (model.getRawPredictionCol != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.withColumn(map(model.rawPredictionCol),
- callUDF(features2raw, new VectorUDT, col(map(model.featuresCol))))
+ tmpData = tmpData.withColumn(model.getRawPredictionCol,
+ callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol)))
numColsOutput += 1
- if (map(model.predictionCol) != "") {
+ if (model.getPredictionCol != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
- tmpData = tmpData.withColumn(map(model.predictionCol),
- callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol))))
+ tmpData = tmpData.withColumn(model.getPredictionCol,
+ callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol)))
numColsOutput += 1
}
- } else if (map(model.predictionCol) != "") {
+ } else if (model.getPredictionCol != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.withColumn(map(model.predictionCol),
- callUDF(features2pred, DoubleType, col(map(model.featuresCol))))
+ tmpData = tmpData.withColumn(model.getPredictionCol,
+ callUDF(features2pred, DoubleType, col(model.getFeaturesCol)))
numColsOutput += 1
}
(numColsOutput, tmpData)
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ee2a8dc..419e5ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,9 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -64,22 +63,20 @@ final class DecisionTreeClassifier
override def setImpurity(value: String): this.type = super.setImpurity(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): DecisionTreeClassificationModel = {
+ override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val oldModel = OldDecisionTree.train(oldDataset, strategy)
- DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -106,7 +103,6 @@ object DecisionTreeClassifier {
@AlphaComponent
final class DecisionTreeClassificationModel private[ml] (
override val parent: DecisionTreeClassifier,
- override val fittingParamMap: ParamMap,
override val rootNode: Node)
extends PredictionModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -118,10 +114,8 @@ final class DecisionTreeClassificationModel private[ml] (
rootNode.predict(features)
}
- override protected def copy(): DecisionTreeClassificationModel = {
- val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
+ copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
}
override def toString: String = {
@@ -140,12 +134,11 @@ private[ml] object DecisionTreeClassificationModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
- new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ new DecisionTreeClassificationModel(parent, rootNode)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3d84986..534ea95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Param, Params, ParamMap}
+import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
@@ -31,12 +31,11 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss}
+import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -112,7 +111,7 @@ final class GBTClassifier
def setLossType(value: String): this.type = set(lossType, value)
/** @group getParam */
- def getLossType: String = getOrDefault(lossType).toLowerCase
+ def getLossType: String = $(lossType).toLowerCase
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
@@ -124,25 +123,23 @@ final class GBTClassifier
}
}
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): GBTClassificationModel = {
+ override protected def train(dataset: DataFrame): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("GBTClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
require(numClasses == 2,
s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
- GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -165,7 +162,6 @@ object GBTClassifier {
@AlphaComponent
final class GBTClassificationModel(
override val parent: GBTClassifier,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double])
extends PredictionModel[Vector, GBTClassificationModel]
@@ -188,10 +184,8 @@ final class GBTClassificationModel(
if (prediction > 0.0) 1.0 else 0.0
}
- override protected def copy(): GBTClassificationModel = {
- val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): GBTClassificationModel = {
+ copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
}
override def toString: String = {
@@ -210,14 +204,13 @@ private[ml] object GBTClassificationModel {
def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent, fittingParamMap for each tree is null since there are no good ways to set these.
- DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index cc8b072..b73be03 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,12 +21,11 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
-
/**
* Params for logistic regression.
*/
@@ -59,9 +58,9 @@ class LogisticRegression
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
- override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -69,17 +68,17 @@ class LogisticRegression
// Train model
val lr = new LogisticRegressionWithLBFGS()
- .setIntercept(paramMap(fitIntercept))
+ .setIntercept($(fitIntercept))
lr.optimizer
- .setRegParam(paramMap(regParam))
- .setNumIterations(paramMap(maxIter))
+ .setRegParam($(regParam))
+ .setNumIterations($(maxIter))
val oldModel = lr.run(oldDataset)
- val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
+ val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)
if (handlePersistence) {
oldDataset.unpersist()
}
- lrm
+ copyValues(lrm)
}
}
@@ -92,7 +91,6 @@ class LogisticRegression
@AlphaComponent
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
@@ -110,16 +108,14 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-m))
}
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This is overridden (a) to be more efficient (avoiding re-computing values when creating
// multiple output columns) and (b) to handle threshold, which the abstractions do not use.
// TODO: We should abstract away the steps defined by UDFs below so that the abstractions
// can call whichever UDFs are needed to create the output columns.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
-
- val map = extractParamMap(paramMap)
+ transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -128,41 +124,41 @@ class LogisticRegressionModel private[ml] (
// prediction (max margin)
var tmpData = dataset
var numColsOutput = 0
- if (map(rawPredictionCol) != "") {
+ if ($(rawPredictionCol) != "") {
val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.withColumn(map(rawPredictionCol),
- callUDF(features2raw, new VectorUDT, col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col($(featuresCol))))
numColsOutput += 1
}
- if (map(probabilityCol) != "") {
- if (map(rawPredictionCol) != "") {
+ if ($(probabilityCol) != "") {
+ if ($(rawPredictionCol) != "") {
val raw2prob = udf { (rawPreds: Vector) =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
Vectors.dense(1.0 - prob1, prob1): Vector
}
- tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol))))
+ tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol))))
} else {
val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
- tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol))))
}
numColsOutput += 1
}
- if (map(predictionCol) != "") {
- val t = map(threshold)
- if (map(probabilityCol) != "") {
+ if ($(predictionCol) != "") {
+ val t = $(threshold)
+ if ($(probabilityCol) != "") {
val predict = udf { probs: Vector =>
if (probs(1) > t) 1.0 else 0.0
}
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol))))
- } else if (map(rawPredictionCol) != "") {
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol))))
+ } else if ($(rawPredictionCol) != "") {
val predict = udf { rawPreds: Vector =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol))))
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol))))
} else {
val predict = udf { features: Vector => this.predict(features) }
- tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol))))
+ tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol))))
}
numColsOutput += 1
}
@@ -193,9 +189,7 @@ class LogisticRegressionModel private[ml] (
Vectors.dense(0.0, m)
}
- override protected def copy(): LogisticRegressionModel = {
- val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): LogisticRegressionModel = {
+ copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
}
}
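For reference, the margin-to-label chain that the UDFs above encode can be sketched in plain Scala (no Spark needed; `threshold` below stands in for `$(threshold)`, and `margin` for `rawPreds(1)`):

~~~scala
// Sketch of the logic inside the raw2prob and predict UDFs above.
def prob1(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

def predictLabel(margin: Double, threshold: Double): Double =
  if (prob1(margin) > threshold) 1.0 else 0.0

println(prob1(0.8))              // ~0.69: P(label = 1) for a margin of 0.8
println(predictLabel(0.8, 0.5))  // 1.0: the probability clears the threshold
~~~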
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1040454..8519841 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -34,12 +33,10 @@ private[classification] trait ProbabilisticClassifierParams
override protected def validateAndTransformSchema(
schema: StructType,
- paramMap: ParamMap,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
- val map = extractParamMap(paramMap)
- SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT)
+ val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
+ SchemaUtils.appendColumn(parentSchema, $(probabilityCol), new VectorUDT)
}
}
@@ -95,36 +92,22 @@ private[spark] abstract class ProbabilisticClassificationModel[
* - probability of each class as [[probabilityCol]] of type [[Vector]].
*
* @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ override def transform(dataset: DataFrame): DataFrame = {
// This default implementation should be overridden as needed.
// Check schema
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
-
- // Prepare model
- val tmpModel = if (paramMap.size != 0) {
- val tmpModel = this.copy()
- Params.inheritValues(paramMap, parent, tmpModel)
- tmpModel
- } else {
- this
- }
+ transformSchema(dataset.schema, logging = true)
val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
// Output selected columns only.
- if (map(probabilityCol) != "") {
+ if ($(probabilityCol) != "") {
// output probabilities
- val features2probs: FeaturesType => Vector = (features) => {
- tmpModel.predictProbabilities(features)
- }
- outputData.withColumn(map(probabilityCol),
- callUDF(features2probs, new VectorUDT, col(map(featuresCol))))
+ outputData.withColumn($(probabilityCol),
+ callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))))
} else {
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
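The schema handling above follows a stackable-trait pattern: each layer validates via `super` and then appends its own column. A minimal plain-Scala sketch of the shape (illustrative names only, not the spark.ml API):

~~~scala
trait PredictorSchema {
  def outputColumns(in: Seq[String]): Seq[String] = in :+ "prediction"
}
trait ProbabilisticSchema extends PredictorSchema {
  // Validate/extend the parent's columns first, then add this layer's column.
  override def outputColumns(in: Seq[String]): Seq[String] =
    super.outputColumns(in) :+ "probability"
}
object Demo extends ProbabilisticSchema
// Demo.outputColumns(Seq("features")) == Seq("features", "prediction", "probability")
~~~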
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index cfd6508..17f59bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -22,18 +22,17 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* :: AlphaComponent ::
*
@@ -81,24 +80,22 @@ final class RandomForestClassifier
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): RandomForestClassificationModel = {
+ override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
- val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
- s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ s" with invalid label column ${$(labelCol)}, without the number of classes" +
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val oldModel = OldRandomForest.trainClassifier(
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
- RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
}
@@ -123,7 +120,6 @@ object RandomForestClassifier {
@AlphaComponent
final class RandomForestClassificationModel private[ml] (
override val parent: RandomForestClassifier,
- override val fittingParamMap: ParamMap,
private val _trees: Array[DecisionTreeClassificationModel])
extends PredictionModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -150,10 +146,8 @@ final class RandomForestClassificationModel private[ml] (
votes.maxBy(_._2)._1
}
- override protected def copy(): RandomForestClassificationModel = {
- val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees)
- Params.inheritValues(this.extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): RandomForestClassificationModel = {
+ copyValues(new RandomForestClassificationModel(parent, _trees), extra)
}
override def toString: String = {
@@ -172,14 +166,13 @@ private[ml] object RandomForestClassificationModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
- fittingParamMap: ParamMap,
categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set it.
- DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures)
+ DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestClassificationModel(parent, fittingParamMap, newTrees)
+ new RandomForestClassificationModel(parent, newTrees)
}
}
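The `numClasses` lookup above fails fast when the label column carries no class-count metadata, pointing the user at `StringIndexer`. A plain-Scala sketch of the same dispatch (`numClassesMeta` stands in for the `MetadataUtils.getNumClasses` result):

~~~scala
def resolveNumClasses(labelCol: String, numClassesMeta: Option[Int]): Int =
  numClassesMeta match {
    case Some(n) => n
    case None => throw new IllegalArgumentException(
      s"RandomForestClassifier was given input with invalid label column $labelCol," +
        " without the number of classes specified. See StringIndexer.")
  }

resolveNumClasses("label", Some(2))  // 2
// resolveNumClasses("label", None) throws with the message above
~~~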
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index c865eb9..e5a73c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -33,8 +33,7 @@ import org.apache.spark.sql.types.DoubleType
* Evaluator for binary classification, which expects two input columns: score and label.
*/
@AlphaComponent
-class BinaryClassificationEvaluator extends Evaluator with Params
- with HasRawPredictionCol with HasLabelCol {
+class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol {
/**
* param for metric name in evaluation
@@ -44,7 +43,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
"metric name in evaluation (areaUnderROC|areaUnderPR)")
/** @group getParam */
- def getMetricName: String = getOrDefault(metricName)
+ def getMetricName: String = $(metricName)
/** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
@@ -57,20 +56,18 @@ class BinaryClassificationEvaluator extends Evaluator with Params
setDefault(metricName -> "areaUnderROC")
- override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
- val map = extractParamMap(paramMap)
-
+ override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
- SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT)
- SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT)
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
- val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
+ val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol))
.map { case Row(rawPrediction: Vector, label: Double) =>
(rawPrediction(1), label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
- val metric = map(metricName) match {
+ val metric = $(metricName) match {
case "areaUnderROC" =>
metrics.areaUnderROC()
case "areaUnderPR" =>
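`evaluate` takes the positive-class score from `rawPrediction(1)` and then dispatches on `$(metricName)`. A sketch of that dispatch (`roc` and `pr` stand in for the `BinaryClassificationMetrics` calls; the fallback case is an assumption, since the hunk ends before any default branch):

~~~scala
def pickMetric(name: String, roc: => Double, pr: => Double): Double = name match {
  case "areaUnderROC" => roc
  case "areaUnderPR"  => pr
  case other =>
    // Assumed fallback; the diff does not show how unknown names are handled.
    throw new IllegalArgumentException(s"Unsupported metric: $other")
}
~~~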
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index f3ce6df..6eb1db6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -44,7 +44,7 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
/** @group getParam */
- def getThreshold: Double = getOrDefault(threshold)
+ def getThreshold: Double = $(threshold)
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
@@ -57,23 +57,21 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val td = map(threshold)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val td = $(threshold)
val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
dataset.select(col("*"),
- binarizer(col(map(inputCol))).as(outputColName, metadata))
+ binarizer(col($(inputCol))).as(outputColName, metadata))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
- val outputColName = map(outputCol)
+ val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
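The binarization itself is a one-liner with strict greater-than semantics, so a value exactly equal to the threshold maps to 0.0:

~~~scala
// The function captured by the `binarizer` udf above, in plain Scala.
val td = 0.5  // stands in for $(threshold)
val binarize: Double => Double = in => if (in > td) 1.0 else 0.0
assert(binarize(0.7) == 1.0)
assert(binarize(0.5) == 0.0)  // equal to the threshold stays 0.0
~~~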
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 0b3128f..c305a81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
@@ -42,13 +42,13 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
setDefault(numFeatures -> (1 << 18))
/** @group getParam */
- def getNumFeatures: Int = getOrDefault(numFeatures)
+ def getNumFeatures: Int = $(numFeatures)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
- val hashingTF = new feature.HashingTF(paramMap(numFeatures))
+ override protected def createTransformFunc: Iterable[_] => Vector = {
+ val hashingTF = new feature.HashingTF($(numFeatures))
hashingTF.transform
}
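With the `ParamMap` argument gone, `createTransformFunc` simply closes over `$(numFeatures)`. The wrapped old-API transformer can be exercised standalone, without a SparkContext:

~~~scala
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}

val tf = new OldHashingTF(1 << 18)          // mirrors the numFeatures default above
val vec = tf.transform(Seq("a", "b", "a"))  // term-frequency Vector; "a" counts twice
~~~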
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index e6a62d9..d901a20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -43,7 +43,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
setDefault(minDocFreq -> 0)
/** @group getParam */
- def getMinDocFreq: Int = getOrDefault(minDocFreq)
+ def getMinDocFreq: Int = $(minDocFreq)
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
@@ -51,10 +51,9 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/**
* Validate and transform the input schema.
*/
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
- SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
@@ -71,18 +70,15 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val idf = new feature.IDF(map(minDocFreq)).fit(input)
- val model = new IDFModel(this, map, idf)
- Params.inheritValues(map, this, model)
- model
+ override def fit(dataset: DataFrame): IDFModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+ val idf = new feature.IDF($(minDocFreq)).fit(input)
+ copyValues(new IDFModel(this, idf))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -93,7 +89,6 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
@AlphaComponent
class IDFModel private[ml] (
override val parent: IDF,
- override val fittingParamMap: ParamMap,
idfModel: feature.IDFModel)
extends Model[IDFModel] with IDFBase {
@@ -103,14 +98,13 @@ class IDFModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val idf = udf { vec: Vector => idfModel.transform(vec) }
- dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
+ dataset.withColumn($(outputCol), idf(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
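From the caller's side, the simplified signatures read as below (a sketch; it assumes a DataFrame `df` with a `Vector` column "tf" and the usual `setInputCol` setter, which the hunk truncates):

~~~scala
val idf = new IDF().setInputCol("tf").setOutputCol("tfidf").setMinDocFreq(2)
val idfModel = idf.fit(df)             // no ParamMap argument any more
val rescaled = idfModel.transform(df)  // embedded params are read via $() internally
~~~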
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index bd2b5f6..755b46a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
@@ -41,13 +41,13 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
setDefault(p -> 2.0)
/** @group getParam */
- def getP: Double = getOrDefault(p)
+ def getP: Double = $(p)
/** @group setParam */
def setP(value: Double): this.type = set(p, value)
- override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
- val normalizer = new feature.Normalizer(paramMap(p))
+ override protected def createTransformFunc: Vector => Vector = {
+ val normalizer = new feature.Normalizer($(p))
normalizer.transform
}
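As with `HashingTF`, the returned function closes over the embedded value of `$(p)`. The underlying old-API `Normalizer` works on local vectors, so the math is easy to check:

~~~scala
import org.apache.spark.mllib.feature.{Normalizer => OldNormalizer}
import org.apache.spark.mllib.linalg.Vectors

val l2 = new OldNormalizer(2.0)        // p = 2.0, matching the default above
l2.transform(Vectors.dense(3.0, 4.0))  // (0.6, 0.8): the L2 norm of (3, 4) is 5
~~~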
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 1b7c939..63e190c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -47,14 +47,13 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp
setDefault(degree -> 2)
/** @group getParam */
- def getDegree: Int = getOrDefault(degree)
+ def getDegree: Int = $(degree)
/** @group setParam */
def setDegree(value: Int): this.type = set(degree, value)
- override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v =>
- val d = paramMap(degree)
- PolynomialExpansion.expand(v, d)
+ override protected def createTransformFunc: Vector => Vector = { v =>
+ PolynomialExpansion.expand(v, $(degree))
}
override protected def outputDataType: DataType = new VectorUDT()
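For an input `(x, y)` at degree 2, the expansion is `(x, x*x, y, x*y, y*y)`, i.e. all monomials up to the given degree except the constant term. A caller-side sketch (assumes a DataFrame `df` with a `Vector` column "features"):

~~~scala
val poly = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("polyFeatures")
  .setDegree(3)
val expanded = poly.transform(df)  // $(degree) is read inside the expansion closure
~~~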
http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index a0e9ed3..7cad59f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -71,25 +71,21 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
- val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd))
+ override def fit(dataset: DataFrame): StandardScalerModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+ val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
- val model = new StandardScalerModel(this, map, scalerModel)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new StandardScalerModel(this, scalerModel))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${map(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains(map(outputCol)),
- s"Output column ${map(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ s"Input column ${$(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}
@@ -101,7 +97,6 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
@AlphaComponent
class StandardScalerModel private[ml] (
override val parent: StandardScaler,
- override val fittingParamMap: ParamMap,
scaler: feature.StandardScalerModel)
extends Model[StandardScalerModel] with StandardScalerParams {
@@ -111,21 +106,19 @@ class StandardScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
- dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val scale = udf { scaler.transform _ }
+ dataset.withColumn($(outputCol), scale(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- val inputType = schema(map(inputCol)).dataType
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${map(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains(map(outputCol)),
- s"Output column ${map(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ s"Input column ${$(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}
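End to end, the estimator/model pair now reads as below (a sketch; it assumes a DataFrame `df` with a `Vector` column "features", plus a `setWithMean` setter alongside the `setWithStd` shown above):

~~~scala
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setWithMean(false)
  .setWithStd(true)
val model = scaler.fit(df)        // learns the column means/variances
val scaled = model.transform(df)  // applies the learned scaling via the udf above
~~~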