Posted to commits@spark.apache.org by me...@apache.org on 2015/05/07 01:15:55 UTC
[1/2] spark git commit: [SPARK-5995] [ML] Make Prediction dev API public
Repository: spark
Updated Branches:
refs/heads/master 774099670 -> 1ad04dae0
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 188d1e5..f6bcdf8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -587,6 +587,28 @@ class DenseVector(val values: Array[Double]) extends Vector {
}
new SparseVector(size, ii, vv)
}
+
+ /**
+ * Find the index of a maximal element. Returns the index of the first maximal element
+ * in case of a tie. Returns -1 if the vector has length 0.
+ */
+ private[spark] def argmax: Int = {
+ if (size == 0) {
+ -1
+ } else {
+ var maxIdx = 0
+ var maxValue = values(0)
+ var i = 1
+ while (i < size) {
+ if (values(i) > maxValue) {
+ maxIdx = i
+ maxValue = values(i)
+ }
+ i += 1
+ }
+ maxIdx
+ }
+ }
}
object DenseVector {
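
For illustration, a minimal standalone Scala sketch (not part of this commit) mirroring the argmax semantics added above: the first maximal index wins ties, and a length-0 input yields -1.

object ArgmaxDemo {
  def argmax(values: Array[Double]): Int = {
    if (values.isEmpty) {
      -1                                   // mirrors the length-0 case above
    } else {
      var maxIdx = 0
      var maxValue = values(0)
      var i = 1
      while (i < values.length) {
        if (values(i) > maxValue) {        // strict '>' keeps the first maximal index
          maxIdx = i
          maxValue = values(i)
        }
        i += 1
      }
      maxIdx
    }
  }

  def main(args: Array[String]): Unit = {
    assert(argmax(Array(0.1, 0.7, 0.7, 0.2)) == 1) // tie: first maximal element wins
    assert(argmax(Array.empty[Double]) == -1)      // length-0 vector
  }
}
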
[2/2] spark git commit: [SPARK-5995] [ML] Make Prediction dev API public
Posted by me...@apache.org.
[SPARK-5995] [ML] Make Prediction dev API public
Changes:
* Update protected prediction methods, following design doc. **<--most interesting change**
* Changed abstract classes for Estimator and Model to be public. Added DeveloperApi tag. (I kept the traits for Estimator/Model Params private.) A sketch of the resulting subclassing pattern follows this list.
* Changed ProbabilisticClassificationModel method names to use probability instead of probabilities.
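
As an illustration of the pattern this change makes public, here is a hedged sketch of a trivial regressor built on the new abstractions. MyMeanRegressor and MyMeanModel are hypothetical names, and the sketch assumes the 1.4-era API shown in Predictor.scala below (abstract val parent on Model, auto-generated uid).

import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame

// Hypothetical Estimator: only train() must be implemented; fit() handles
// schema validation and copies params into the returned model.
class MyMeanRegressor extends Predictor[Vector, MyMeanRegressor, MyMeanModel] {
  override protected def train(dataset: DataFrame): MyMeanModel = {
    val mean = extractLabeledPoints(dataset).map(_.label).mean()
    new MyMeanModel(this, mean)
  }
}

// Hypothetical Model: only predict() must be implemented; the inherited
// transform() reads featuresCol and writes predictionCol.
class MyMeanModel(override val parent: MyMeanRegressor, val mean: Double)
  extends PredictionModel[Vector, MyMeanModel] {
  override protected def predict(features: Vector): Double = mean
}
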
CC: mengxr shivaram etrain
Author: Joseph K. Bradley <jo...@databricks.com>
Closes #5913 from jkbradley/public-dev-api and squashes the following commits:
e9aa0ea [Joseph K. Bradley] moved findMax to DenseVector and renamed to argmax. fixed bug for vector of length 0
15b9957 [Joseph K. Bradley] renamed probabilities to probability in method names
5cda84d [Joseph K. Bradley] regenerated sharedParams
7d1877a [Joseph K. Bradley] Made spark.ml prediction abstractions public. Organized their prediction methods for efficient computation of multiple output columns.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ad04dae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ad04dae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ad04dae
Branch: refs/heads/master
Commit: 1ad04dae038673a448f529c39b17817b78d6acd0
Parents: 7740996
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Wed May 6 16:15:51 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed May 6 16:15:51 2015 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/ml/Predictor.scala | 191 ++++++++
.../spark/ml/classification/Classifier.scala | 110 ++---
.../classification/DecisionTreeClassifier.scala | 5 +-
.../spark/ml/classification/GBTClassifier.scala | 5 +-
.../ml/classification/LogisticRegression.scala | 100 ++---
.../ProbabilisticClassifier.scala | 100 +++--
.../classification/RandomForestClassifier.scala | 5 +-
.../spark/ml/impl/estimator/Predictor.scala | 217 ----------
.../apache/spark/ml/impl/tree/treeParams.scala | 431 -------------------
.../ml/param/shared/SharedParamsCodeGen.scala | 6 +-
.../spark/ml/param/shared/sharedParams.scala | 4 +-
.../ml/regression/DecisionTreeRegressor.scala | 5 +-
.../spark/ml/regression/GBTRegressor.scala | 5 +-
.../spark/ml/regression/LinearRegression.scala | 5 +-
.../ml/regression/RandomForestRegressor.scala | 5 +-
.../apache/spark/ml/regression/Regressor.scala | 42 +-
.../org/apache/spark/ml/tree/treeParams.scala | 431 +++++++++++++++++++
.../org/apache/spark/mllib/linalg/Vectors.scala | 22 +
18 files changed, 814 insertions(+), 875 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
new file mode 100644
index 0000000..0e53877
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
+
+/**
+ * (private[ml]) Trait for parameters for prediction (regression and classification).
+ */
+private[ml] trait PredictorParams extends Params
+ with HasLabelCol with HasFeaturesCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param fitting whether this call is made during model fitting
+ * @param featuresDataType SQL DataType for FeaturesType.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
+ SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
+ if (fitting) {
+ // TODO: Allow other numeric types
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ }
+ SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Abstraction for prediction problems (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam Learner Specialization of this class. If you subclass this type, use this type
+ * parameter to specify the concrete type.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ */
+@DeveloperApi
+abstract class Predictor[
+ FeaturesType,
+ Learner <: Predictor[FeaturesType, Learner, M],
+ M <: PredictionModel[FeaturesType, M]]
+ extends Estimator[M] with PredictorParams {
+
+ /** @group setParam */
+ def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
+ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+ override def fit(dataset: DataFrame): M = {
+ // This handles a few items such as schema validation.
+ // Developers only need to implement train().
+ transformSchema(dataset.schema, logging = true)
+ copyValues(train(dataset))
+ }
+
+ override def copy(extra: ParamMap): Learner = {
+ super.copy(extra).asInstanceOf[Learner]
+ }
+
+ /**
+ * Train a model using the given dataset and parameters.
+ * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+ * and copying parameters into the model.
+ *
+ * @param dataset Training dataset
+ * @return Fitted model
+ */
+ protected def train(dataset: DataFrame): M
+
+ /**
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ protected def featuresDataType: DataType = new VectorUDT
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true, featuresDataType)
+ }
+
+ /**
+ * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+ * and put it in an RDD with strong types.
+ */
+ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
+ dataset.select($(labelCol), $(featuresCol))
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ */
+@DeveloperApi
+abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+ extends Model[M] with PredictorParams {
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+ /** @group setParam */
+ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+ /**
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ protected def featuresDataType: DataType = new VectorUDT
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false, featuresDataType)
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
+ * the predictions as a new column [[predictionCol]].
+ *
+ * @param dataset input dataset
+ * @return transformed dataset with [[predictionCol]] of type [[Double]]
+ */
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ if ($(predictionCol).nonEmpty) {
+ dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
+ } else {
+ this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
+ " since no output columns were set.")
+ dataset
+ }
+ }
+
+ /**
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ protected def predict(features: FeaturesType): Double
+}
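
A hedged end-to-end usage sketch of the fit()/transform() contract defined above, reusing the hypothetical MyMeanRegressor from the sketch in the commit message and the Spark 1.4-era SQLContext API:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

object PredictorUsageDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(0.0)),
      LabeledPoint(3.0, Vectors.dense(1.0)))).toDF()

    val model = new MyMeanRegressor().fit(training) // schema checked, then train()
    model.transform(training).show()                // adds a "prediction" column (2.0)
    sc.stop()
  }
}
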
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d3361e2..263d580 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,8 +17,8 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -26,15 +26,12 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
/**
- * :: DeveloperApi ::
- * Params for classification.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ * (private[spark]) Params for classification.
*/
-@DeveloperApi
-private[spark] trait ClassifierParams extends PredictorParams
- with HasRawPredictionCol {
+private[spark] trait ClassifierParams
+ extends PredictorParams with HasRawPredictionCol {
override protected def validateAndTransformSchema(
schema: StructType,
@@ -46,23 +43,21 @@ private[spark] trait ClassifierParams extends PredictorParams
}
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
+ *
* Single-label binary or multiclass classification.
* Classes are indexed {0, 1, ..., numClasses - 1}.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam E Concrete Estimator type
* @tparam M Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
-private[spark] abstract class Classifier[
+@DeveloperApi
+abstract class Classifier[
FeaturesType,
E <: Classifier[FeaturesType, E, M],
M <: ClassificationModel[FeaturesType, M]]
- extends Predictor[FeaturesType, E, M]
- with ClassifierParams {
+ extends Predictor[FeaturesType, E, M] with ClassifierParams {
/** @group setParam */
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
@@ -71,17 +66,15 @@ private[spark] abstract class Classifier[
}
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
+ *
* Model produced by a [[Classifier]].
* Classes are indexed {0, 1, ..., numClasses - 1}.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam M Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
-private[spark]
+@DeveloperApi
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {
@@ -101,13 +94,27 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return transformed dataset
*/
override def transform(dataset: DataFrame): DataFrame = {
- // This default implementation should be overridden as needed.
-
- // Check schema
transformSchema(dataset.schema, logging = true)
- val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ var outputData = dataset
+ var numColsOutput = 0
+ if (getRawPredictionCol != "") {
+ outputData = outputData.withColumn(getRawPredictionCol,
+ callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+ numColsOutput += 1
+ }
+ if (getPredictionCol != "") {
+ val predUDF = if (getRawPredictionCol != "") {
+ callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
+ } else {
+ callUDF(predict _, DoubleType, col(getFeaturesCol))
+ }
+ outputData = outputData.withColumn(getPredictionCol, predUDF)
+ numColsOutput += 1
+ }
+
if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
@@ -116,22 +123,17 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
}
/**
- * :: DeveloperApi ::
- *
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
*
* This default implementation for classification predicts the index of the maximum value
* from [[predictRaw()]].
*/
- @DeveloperApi
override protected def predict(features: FeaturesType): Double = {
- predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
+ raw2prediction(predictRaw(features))
}
/**
- * :: DeveloperApi ::
- *
* Raw prediction for each possible label.
* The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
* a measure of confidence in each possible label (where larger = more confident).
@@ -141,48 +143,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* This raw prediction may be any real number, where a larger value indicates greater
* confidence for that label.
*/
- @DeveloperApi
protected def predictRaw(features: FeaturesType): Vector
-}
-
-private[ml] object ClassificationModel {
/**
- * Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
- * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
- * @param dataset Input dataset
- * @return (number of columns added, transformed dataset)
+ * Given a vector of raw predictions, select the predicted label.
+ * This may be overridden to support thresholds which favor particular labels.
+ * @return predicted label
*/
- def transformColumnsImpl[FeaturesType](
- dataset: DataFrame,
- model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = {
-
- // Output selected columns only.
- // This is a bit complicated since it tries to avoid repeated computation.
- var tmpData = dataset
- var numColsOutput = 0
- if (model.getRawPredictionCol != "") {
- // output raw prediction
- val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.withColumn(model.getRawPredictionCol,
- callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol)))
- numColsOutput += 1
- if (model.getPredictionCol != "") {
- val raw2pred: Vector => Double = (rawPred) => {
- rawPred.toArray.zipWithIndex.maxBy(_._1)._2
- }
- tmpData = tmpData.withColumn(model.getPredictionCol,
- callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol)))
- numColsOutput += 1
- }
- } else if (model.getPredictionCol != "") {
- // output prediction
- val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.withColumn(model.getPredictionCol,
- callUDF(features2pred, DoubleType, col(model.getFeaturesCol)))
- numColsOutput += 1
- }
- (numColsOutput, tmpData)
- }
-
+ protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
}
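
As the raw2prediction doc above notes, subclasses may override it to favor particular labels. A hedged standalone sketch of such a threshold rule, written over plain arrays rather than Vector:

object ThresholdedPredictionDemo {
  // Hypothetical binary rule: predict label 1 only when its raw score beats
  // label 0 by more than `margin`; otherwise fall back to label 0.
  def raw2prediction(raw: Array[Double], margin: Double): Double =
    if (raw(1) - raw(0) > margin) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    println(raw2prediction(Array(0.4, 0.6), margin = 0.5)) // 0.0: lead over label 0 too small
    println(raw2prediction(Array(0.0, 0.9), margin = 0.5)) // 1.0
  }
}
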
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 419e5ba..dcebea1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,10 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 534ea95..ae51b05 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -21,11 +21,10 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index b73be03..550369d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,9 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
/**
@@ -99,76 +98,17 @@ class LogisticRegressionModel private[ml] (
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
+ /** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
BLAS.dot(features, weights) + intercept
}
+ /** Score (probability) for class label 1. For binary classification only. */
private val score: Vector => Double = (features) => {
val m = margin(features)
1.0 / (1.0 + math.exp(-m))
}
- override def transform(dataset: DataFrame): DataFrame = {
- // This is overridden (a) to be more efficient (avoiding re-computing values when creating
- // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
- // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
- // can call whichever UDFs are needed to create the output columns.
-
- // Check schema
- transformSchema(dataset.schema, logging = true)
-
- // Output selected columns only.
- // This is a bit complicated since it tries to avoid repeated computation.
- // rawPrediction (-margin, margin)
- // probability (1.0-score, score)
- // prediction (max margin)
- var tmpData = dataset
- var numColsOutput = 0
- if ($(rawPredictionCol) != "") {
- val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.withColumn($(rawPredictionCol),
- callUDF(features2raw, new VectorUDT, col($(featuresCol))))
- numColsOutput += 1
- }
- if ($(probabilityCol) != "") {
- if ($(rawPredictionCol) != "") {
- val raw2prob = udf { (rawPreds: Vector) =>
- val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
- Vectors.dense(1.0 - prob1, prob1): Vector
- }
- tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol))))
- } else {
- val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
- tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol))))
- }
- numColsOutput += 1
- }
- if ($(predictionCol) != "") {
- val t = $(threshold)
- if ($(probabilityCol) != "") {
- val predict = udf { probs: Vector =>
- if (probs(1) > t) 1.0 else 0.0
- }
- tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol))))
- } else if ($(rawPredictionCol) != "") {
- val predict = udf { rawPreds: Vector =>
- val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
- if (prob1 > t) 1.0 else 0.0
- }
- tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol))))
- } else {
- val predict = udf { features: Vector => this.predict(features) }
- tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol))))
- }
- numColsOutput += 1
- }
- if (numColsOutput == 0) {
- this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
- " since no output columns were set.")
- }
- tmpData
- }
-
override val numClasses: Int = 2
/**
@@ -179,17 +119,43 @@ class LogisticRegressionModel private[ml] (
if (score(features) > getThreshold) 1 else 0
}
- override protected def predictProbabilities(features: Vector): Vector = {
- val s = score(features)
- Vectors.dense(1.0 - s, s)
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ while (i < dv.size) {
+ dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
+ i += 1
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
+ " raw2probabilitiesInPlace encountered SparseVector")
+ }
}
override protected def predictRaw(features: Vector): Vector = {
val m = margin(features)
- Vectors.dense(0.0, m)
+ Vectors.dense(-m, m)
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
}
+
+ override protected def raw2prediction(rawPrediction: Vector): Double = {
+ val t = getThreshold
+ val rawThreshold = if (t == 0.0) {
+ Double.NegativeInfinity
+ } else if (t == 1.0) {
+ Double.PositiveInfinity
+ } else {
+ Math.log(t / (1.0 - t))
+ }
+ if (rawPrediction(1) > rawThreshold) 1 else 0
+ }
+
+ override protected def probability2prediction(probability: Vector): Double = {
+ if (probability(1) > getThreshold) 1 else 0
+ }
}
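
The raw2prediction override above avoids computing the sigmoid by translating the probability threshold t into a margin threshold log(t / (1 - t)) (with t = 0 and t = 1 mapping to -/+ infinity). A small standalone check of that equivalence:

object ThresholdEquivalenceDemo {
  def sigmoid(m: Double): Double = 1.0 / (1.0 + math.exp(-m))

  def main(args: Array[String]): Unit = {
    val t = 0.7
    val rawThreshold = math.log(t / (1.0 - t)) // ~0.847
    for (margin <- Seq(-2.0, 0.0, 1.0, 3.0)) {
      // sigmoid(margin) > t  iff  margin > log(t / (1 - t))
      assert((sigmoid(margin) > t) == (margin > rawThreshold))
    }
  }
}
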
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 8519841..330ae29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -17,16 +17,16 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
/**
- * Params for probabilistic classification.
+ * (private[classification]) Params for probabilistic classification.
*/
private[classification] trait ProbabilisticClassifierParams
extends ClassifierParams with HasProbabilityCol {
@@ -42,17 +42,15 @@ private[classification] trait ProbabilisticClassifierParams
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
*
* Single-label binary or multiclass classifier which can output class conditional probabilities.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam E Concrete Estimator type
* @tparam M Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
+@DeveloperApi
private[spark] abstract class ProbabilisticClassifier[
FeaturesType,
E <: ProbabilisticClassifier[FeaturesType, E, M],
@@ -65,17 +63,15 @@ private[spark] abstract class ProbabilisticClassifier[
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
*
* Model produced by a [[ProbabilisticClassifier]].
* Classes are indexed {0, 1, ..., numClasses - 1}.
*
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
* @tparam M Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
+@DeveloperApi
private[spark] abstract class ProbabilisticClassificationModel[
FeaturesType,
M <: ProbabilisticClassificationModel[FeaturesType, M]]
@@ -95,39 +91,79 @@ private[spark] abstract class ProbabilisticClassificationModel[
* @return transformed dataset
*/
override def transform(dataset: DataFrame): DataFrame = {
- // This default implementation should be overridden as needed.
-
- // Check schema
transformSchema(dataset.schema, logging = true)
- val (numColsOutput, outputData) =
- ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
-
// Output selected columns only.
- if ($(probabilityCol) != "") {
- // output probabilities
- outputData.withColumn($(probabilityCol),
- callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))))
- } else {
- if (numColsOutput == 0) {
- this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
- " since no output columns were set.")
+ // This is a bit complicated since it tries to avoid repeated computation.
+ var outputData = dataset
+ var numColsOutput = 0
+ if ($(rawPredictionCol).nonEmpty) {
+ outputData = outputData.withColumn(getRawPredictionCol,
+ callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+ numColsOutput += 1
+ }
+ if ($(probabilityCol).nonEmpty) {
+ val probUDF = if ($(rawPredictionCol).nonEmpty) {
+ callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol)))
+ } else {
+ callUDF(predictProbability _, new VectorUDT, col($(featuresCol)))
+ }
+ outputData = outputData.withColumn($(probabilityCol), probUDF)
+ numColsOutput += 1
+ }
+ if ($(predictionCol).nonEmpty) {
+ val predUDF = if ($(rawPredictionCol).nonEmpty) {
+ callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
+ } else if ($(probabilityCol).nonEmpty) {
+ callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
+ } else {
+ callUDF(predict _, DoubleType, col($(featuresCol)))
}
- outputData
+ outputData = outputData.withColumn($(predictionCol), predUDF)
+ numColsOutput += 1
+ }
+
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
+ " since no output columns were set.")
}
+ outputData
}
/**
- * :: DeveloperApi ::
+ * Estimate the probability of each class given the raw prediction,
+ * doing the computation in-place.
+ * These predictions are also called class conditional probabilities.
+ *
+ * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
*
+ * @return Estimated class conditional probabilities (modified input vector)
+ */
+ protected def raw2probabilityInPlace(rawPrediction: Vector): Vector
+
+ /** Non-in-place version of [[raw2probabilityInPlace()]] */
+ protected def raw2probability(rawPrediction: Vector): Vector = {
+ val probs = rawPrediction.copy
+ raw2probabilityInPlace(probs)
+ }
+
+ /**
* Predict the probability of each class given the features.
* These predictions are also called class conditional probabilities.
*
- * WARNING: Not all models output well-calibrated probability estimates! These probabilities
- * should be treated as confidences, not precise probabilities.
- *
* This internal method is used to implement [[transform()]] and output [[probabilityCol]].
+ *
+ * @return Estimated class conditional probabilities
+ */
+ protected def predictProbability(features: FeaturesType): Vector = {
+ val rawPreds = predictRaw(features)
+ raw2probabilityInPlace(rawPreds)
+ }
+
+ /**
+ * Given a vector of class conditional probabilities, select the predicted label.
+ * This may be overridden to support thresholds which favor particular labels.
+ * @return predicted label
*/
- @DeveloperApi
- protected def predictProbabilities(features: FeaturesType): Vector
+ protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
}
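
A hedged standalone sketch of the compute-once cascade above (raw -> probability -> prediction), using a softmax as a stand-in for a concrete model's raw2probabilityInPlace and plain arrays in place of Spark UDFs:

object ProbabilisticCascadeDemo {
  // Stand-in for raw2probabilityInPlace: softmax, overwriting the input array.
  def raw2probabilityInPlace(raw: Array[Double]): Array[Double] = {
    val sum = raw.map(math.exp).sum
    var i = 0
    while (i < raw.length) {
      raw(i) = math.exp(raw(i)) / sum
      i += 1
    }
    raw
  }

  // Non-in-place version, mirroring raw2probability above: copy, then transform.
  def raw2probability(raw: Array[Double]): Array[Double] =
    raw2probabilityInPlace(raw.clone())

  // Mirrors probability2prediction: index of the first maximal probability.
  def probability2prediction(prob: Array[Double]): Double =
    prob.indexOf(prob.max).toDouble

  def main(args: Array[String]): Unit = {
    val raw = Array(0.2, 1.3, -0.5)
    val prob = raw2probability(raw)          // raw itself is left untouched
    println(probability2prediction(prob))    // 1.0
  }
}
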
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 17f59bb..9954893 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -20,10 +20,9 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
deleted file mode 100644
index e8b3628..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.impl.estimator
-
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
-
-/**
- * :: DeveloperApi ::
- *
- * Trait for parameters for prediction (regression and classification).
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] trait PredictorParams extends Params
- with HasLabelCol with HasFeaturesCol with HasPredictionCol {
-
- /**
- * Validates and transforms the input schema with the provided param map.
- * @param schema input schema
- * @param fitting whether this is in fitting
- * @param featuresDataType SQL DataType for FeaturesType.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
- * @return output schema
- */
- protected def validateAndTransformSchema(
- schema: StructType,
- fitting: Boolean,
- featuresDataType: DataType): StructType = {
- // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
- SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
- if (fitting) {
- // TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
- }
- SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
- }
-}
-
-/**
- * :: AlphaComponent ::
- *
- * Abstraction for prediction problems (regression and classification).
- *
- * @tparam FeaturesType Type of features.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
- * @tparam Learner Specialization of this class. If you subclass this type, use this type
- * parameter to specify the concrete type.
- * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
- * parameter to specify the concrete type for the corresponding model.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@AlphaComponent
-private[spark] abstract class Predictor[
- FeaturesType,
- Learner <: Predictor[FeaturesType, Learner, M],
- M <: PredictionModel[FeaturesType, M]]
- extends Estimator[M] with PredictorParams {
-
- /** @group setParam */
- def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
-
- /** @group setParam */
- def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
-
- /** @group setParam */
- def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
-
- override def fit(dataset: DataFrame): M = {
- // This handles a few items such as schema validation.
- // Developers only need to implement train().
- transformSchema(dataset.schema, logging = true)
- copyValues(train(dataset))
- }
-
- override def copy(extra: ParamMap): Learner = {
- super.copy(extra).asInstanceOf[Learner]
- }
-
- /**
- * :: DeveloperApi ::
- *
- * Train a model using the given dataset and parameters.
- * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
- * and copying parameters into the model.
- *
- * @param dataset Training dataset
- * @return Fitted model
- */
- @DeveloperApi
- protected def train(dataset: DataFrame): M
-
- /**
- * :: DeveloperApi ::
- *
- * Returns the SQL DataType corresponding to the FeaturesType type parameter.
- *
- * This is used by [[validateAndTransformSchema()]].
- * This workaround is needed since SQL has different APIs for Scala and Java.
- *
- * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
- */
- @DeveloperApi
- protected def featuresDataType: DataType = new VectorUDT
-
- override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = true, featuresDataType)
- }
-
- /**
- * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
- * and put it in an RDD with strong types.
- */
- protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
- dataset.select($(labelCol), $(featuresCol))
- .map { case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
- }
- }
-}
-
-/**
- * :: AlphaComponent ::
- *
- * Abstraction for a model for prediction tasks (regression and classification).
- *
- * @tparam FeaturesType Type of features.
- * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
- * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
- * parameter to specify the concrete type for the corresponding model.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@AlphaComponent
-private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
- extends Model[M] with PredictorParams {
-
- /** @group setParam */
- def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
-
- /** @group setParam */
- def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
-
- /**
- * :: DeveloperApi ::
- *
- * Returns the SQL DataType corresponding to the FeaturesType type parameter.
- *
- * This is used by [[validateAndTransformSchema()]].
- * This workaround is needed since SQL has different APIs for Scala and Java.
- *
- * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
- */
- @DeveloperApi
- protected def featuresDataType: DataType = new VectorUDT
-
- override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false, featuresDataType)
- }
-
- /**
- * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
- * the predictions as a new column [[predictionCol]].
- *
- * @param dataset input dataset
- * @return transformed dataset with [[predictionCol]] of type [[Double]]
- */
- override def transform(dataset: DataFrame): DataFrame = {
- // This default implementation should be overridden as needed.
-
- // Check schema
- transformSchema(dataset.schema, logging = true)
-
- if ($(predictionCol) != "") {
- dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
- } else {
- this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
- " since no output columns were set.")
- dataset
- }
- }
-
- /**
- * :: DeveloperApi ::
- *
- * Predict label for the given features.
- * This internal method is used to implement [[transform()]] and output [[predictionCol]].
- */
- @DeveloperApi
- protected def predict(features: FeaturesType): Double
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
deleted file mode 100644
index 0e22562..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ /dev/null
@@ -1,431 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.impl.tree
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.impl.estimator.PredictorParams
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Decision Tree-based algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait DecisionTreeParams extends PredictorParams {
-
- /**
- * Maximum depth of the tree (>= 0).
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (default = 5)
- * @group param
- */
- final val maxDepth: IntParam =
- new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
- " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
- ParamValidators.gtEq(0))
-
- /**
- * Maximum number of bins used for discretizing continuous features and for choosing how to split
- * on features at each node. More bins give higher granularity.
- * Must be >= 2 and >= number of categories in any categorical feature.
- * (default = 32)
- * @group param
- */
- final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
- " discretizing continuous features. Must be >=2 and >= number of categories for any" +
- " categorical feature.", ParamValidators.gtEq(2))
-
- /**
- * Minimum number of instances each child must have after split.
- * If a split causes the left or right child to have fewer than minInstancesPerNode,
- * the split will be discarded as invalid.
- * Should be >= 1.
- * (default = 1)
- * @group param
- */
- final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
- " number of instances each child must have after split. If a split causes the left or right" +
- " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
- " Should be >= 1.", ParamValidators.gtEq(1))
-
- /**
- * Minimum information gain for a split to be considered at a tree node.
- * (default = 0.0)
- * @group param
- */
- final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
- "Minimum information gain for a split to be considered at a tree node.")
-
- /**
- * Maximum memory in MB allocated to histogram aggregation.
- * (default = 256 MB)
- * @group expertParam
- */
- final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
- "Maximum memory in MB allocated to histogram aggregation.",
- ParamValidators.gtEq(0))
-
- /**
- * If false, the algorithm will pass trees to executors to match instances with nodes.
- * If true, the algorithm will cache node IDs for each instance.
- * Caching can speed up training of deeper trees.
- * (default = false)
- * @group expertParam
- */
- final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
- " algorithm will pass trees to executors to match instances with nodes. If true, the" +
- " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
- " trees.")
-
- /**
- * Specifies how often to checkpoint the cached node IDs.
- * E.g. 10 means that the cache will get checkpointed every 10 iterations.
- * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
- * [[org.apache.spark.SparkContext]].
- * Must be >= 1.
- * (default = 10)
- * @group expertParam
- */
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
- " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
- " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
- " checkpoint directory is set in the SparkContext. Must be >= 1.",
- ParamValidators.gtEq(1))
-
- setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
- maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
-
- /** @group setParam */
- def setMaxDepth(value: Int): this.type = set(maxDepth, value)
-
- /** @group getParam */
- final def getMaxDepth: Int = $(maxDepth)
-
- /** @group setParam */
- def setMaxBins(value: Int): this.type = set(maxBins, value)
-
- /** @group getParam */
- final def getMaxBins: Int = $(maxBins)
-
- /** @group setParam */
- def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
-
- /** @group getParam */
- final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
-
- /** @group setParam */
- def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
-
- /** @group getParam */
- final def getMinInfoGain: Double = $(minInfoGain)
-
- /** @group expertSetParam */
- def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
-
- /** @group expertGetParam */
- final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
-
- /** @group expertSetParam */
- def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
-
- /** @group expertGetParam */
- final def getCacheNodeIds: Boolean = $(cacheNodeIds)
-
- /** @group expertSetParam */
- def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
-
- /** @group expertGetParam */
- final def getCheckpointInterval: Int = $(checkpointInterval)
-
- /** (private[ml]) Create a Strategy instance to use with the old API. */
- private[ml] def getOldStrategy(
- categoricalFeatures: Map[Int, Int],
- numClasses: Int,
- oldAlgo: OldAlgo.Algo,
- oldImpurity: OldImpurity,
- subsamplingRate: Double): OldStrategy = {
- val strategy = OldStrategy.defaultStategy(oldAlgo)
- strategy.impurity = oldImpurity
- strategy.checkpointInterval = getCheckpointInterval
- strategy.maxBins = getMaxBins
- strategy.maxDepth = getMaxDepth
- strategy.maxMemoryInMB = getMaxMemoryInMB
- strategy.minInfoGain = getMinInfoGain
- strategy.minInstancesPerNode = getMinInstancesPerNode
- strategy.useNodeIdCache = getCacheNodeIds
- strategy.numClasses = numClasses
- strategy.categoricalFeaturesInfo = categoricalFeatures
- strategy.subsamplingRate = subsamplingRate
- strategy
- }
-}
-
-/**
- * Parameters for Decision Tree-based classification algorithms.
- */
-private[ml] trait TreeClassifierParams extends Params {
-
- /**
- * Criterion used for information gain calculation (case-insensitive).
- * Supported: "entropy" and "gini".
- * (default = gini)
- * @group param
- */
- final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
- " information gain calculation (case-insensitive). Supported options:" +
- s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
- (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
-
- setDefault(impurity -> "gini")
-
- /** @group setParam */
- def setImpurity(value: String): this.type = set(impurity, value)
-
- /** @group getParam */
- final def getImpurity: String = $(impurity).toLowerCase
-
- /** Convert new impurity to old impurity. */
- private[ml] def getOldImpurity: OldImpurity = {
- getImpurity match {
- case "entropy" => OldEntropy
- case "gini" => OldGini
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(
- s"TreeClassifierParams was given unrecognized impurity: $impurity.")
- }
- }
-}
-
-private[ml] object TreeClassifierParams {
- // These options should be lowercase.
- final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
-}
-
-/**
- * Parameters for Decision Tree-based regression algorithms.
- */
-private[ml] trait TreeRegressorParams extends Params {
-
- /**
- * Criterion used for information gain calculation (case-insensitive).
- * Supported: "variance".
- * (default = variance)
- * @group param
- */
- final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
- " information gain calculation (case-insensitive). Supported options:" +
- s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
- (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
-
- setDefault(impurity -> "variance")
-
- /** @group setParam */
- def setImpurity(value: String): this.type = set(impurity, value)
-
- /** @group getParam */
- final def getImpurity: String = $(impurity).toLowerCase
-
- /** Convert new impurity to old impurity. */
- private[ml] def getOldImpurity: OldImpurity = {
- getImpurity match {
- case "variance" => OldVariance
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(
- s"TreeRegressorParams was given unrecognized impurity: $impurity")
- }
- }
-}
-
-private[ml] object TreeRegressorParams {
- // These options should be lowercase.
- final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Decision Tree-based ensemble algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
-
- /**
- * Fraction of the training data used for learning each decision tree, in range (0, 1].
- * (default = 1.0)
- * @group param
- */
- final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
- "Fraction of the training data used for learning each decision tree, in range (0, 1].",
- ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
-
- setDefault(subsamplingRate -> 1.0)
-
- /** @group setParam */
- def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
-
- /** @group getParam */
- final def getSubsamplingRate: Double = $(subsamplingRate)
-
- /** @group setParam */
- def setSeed(value: Long): this.type = set(seed, value)
-
- /**
- * Create a Strategy instance to use with the old API.
- * NOTE: The caller should set impurity and seed.
- */
- private[ml] def getOldStrategy(
- categoricalFeatures: Map[Int, Int],
- numClasses: Int,
- oldAlgo: OldAlgo.Algo,
- oldImpurity: OldImpurity): OldStrategy = {
- super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
-
- /**
- * The number of features to consider for splits at each tree node.
- * Supported options:
- * - "auto": Choose automatically for task:
- * If numTrees == 1, set to "all."
- * If numTrees > 1 (forest), set to "sqrt" for classification and
- * to "onethird" for regression.
- * - "all": use all features
- * - "onethird": use 1/3 of the features
- * - "sqrt": use sqrt(number of features)
- * - "log2": use log2(number of features)
- * (default = "auto")
- *
- * These various settings are based on the following references:
- * - log2: tested in Breiman (2001)
- * - sqrt: recommended by Breiman manual for random forests
- * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
- * package.
- * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
- * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
- * random forests]]
- *
- * @group param
- */
- final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
- "The number of features to consider for splits at each tree node." +
- s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
- (value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
-
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
-
- /** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
-
- /** @group getParam */
- final def getNumTrees: Int = $(numTrees)
-
- /** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
-
- /** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
-}
-
-private[ml] object RandomForestParams {
- // These options should be lowercase.
- final val supportedFeatureSubsetStrategies: Array[String] =
- Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Gradient-Boosted Tree algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
-
- /**
- * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
- * estimator.
- * (default = 0.1)
- * @group param
- */
- final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
- " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
- ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
-
- /* TODO: Add this doc when we add this param. SPARK-7132
- * Threshold for stopping early when runWithValidation is used.
- * If the error rate on the validation input changes by less than the validationTol,
- * then learning will stop early (before [[numIterations]]).
- * This parameter is ignored when run is used.
- * (default = 1e-5)
- * @group param
- */
- // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
- // validationTol -> 1e-5
-
- setDefault(maxIter -> 20, stepSize -> 0.1)
-
- /** @group setParam */
- def setMaxIter(value: Int): this.type = set(maxIter, value)
-
- /** @group setParam */
- def setStepSize(value: Double): this.type = set(stepSize, value)
-
- /** @group getParam */
- final def getStepSize: Double = $(stepSize)
-
- /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
- private[ml] def getOldBoostingStrategy(
- categoricalFeatures: Map[Int, Int],
- oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
- // NOTE: The old API does not support "seed" so we ignore it.
- new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
- }
-
- /** Get old Gradient Boosting Loss type */
- private[ml] def getOldLossType: OldLoss
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index d379172..0e1ff97 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -40,8 +40,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
Some("\"rawPrediction\"")),
- ParamDesc[String]("probabilityCol",
- "column name for predicted class conditional probabilities", Some("\"probability\"")),
+ ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" +
+ " probabilities. Note: Not all models output well-calibrated probability estimates!" +
+ " These probabilities should be treated as confidences, not precise probabilities.",
+ Some("\"probability\"")),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]",
isValid = "ParamValidators.inRange(0, 1)"),
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index fb1874c..87f8680 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -128,10 +128,10 @@ private[ml] trait HasRawPredictionCol extends Params {
private[ml] trait HasProbabilityCol extends Params {
/**
- * Param for column name for predicted class conditional probabilities.
+ * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities..
* @group param
*/
- final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities")
+ final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
setDefault(probabilityCol, "probability")
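
Usage-wise, this param behaves like any other column-name param. A minimal sketch, assuming the ProbabilisticClassifier setters elsewhere in this patch; the column name is illustrative:

    import org.apache.spark.ml.classification.LogisticRegression

    // Rename the probability output column. Per the doc above, treat the
    // values as confidences rather than calibrated probabilities.
    val lr = new LogisticRegression().setProbabilityCol("classConfidence")
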
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index b07c26f..f8f0b16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -18,10 +18,9 @@
package org.apache.spark.ml.regression
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
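
With the prediction abstractions now public under org.apache.spark.ml, downstream code imports the regressor directly. A minimal, hedged usage sketch; the `training` DataFrame and its "label"/"features" columns are assumptions:

    import org.apache.spark.ml.regression.DecisionTreeRegressor

    // Assumes a DataFrame `training` with "label" and "features" columns.
    val dt = new DecisionTreeRegressor()
      .setMaxDepth(4)
      .setMaxBins(32)
    val model = dt.fit(training)
    val scored = model.transform(training)  // appends the "prediction" column
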
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index bc79695..461905c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -21,10 +21,9 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 66c475f..e63c9a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -39,7 +40,7 @@ import org.apache.spark.util.StatCounter
/**
* Params for linear regression.
*/
-private[regression] trait LinearRegressionParams extends RegressorParams
+private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
/**
@@ -240,7 +241,7 @@ class LinearRegressionModel private[ml] (
* + \bar{y} / \hat{y}||^2
* = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
* }}}
- * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is
+ * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is
* {{{
* - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
* }}}, and diff is
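
Restated in LaTeX for readability, the identity in the scaladoc above (with \hat{x_i}, \hat{y} the feature and label standard deviations and \bar{x_i}, \bar{y} the means) reads:

    \frac{1}{2n}\Big\|\sum_i w_i' x_i - \frac{y}{\hat{y}} + \mathrm{offset}\Big\|^2
      = \frac{1}{2n}\,\mathrm{diff}^2,
    \qquad w_i' = \frac{w_i}{\hat{x_i}},
    \qquad \mathrm{offset} = -\sum_i \frac{w_i}{\hat{x_i}}\,\bar{x_i} + \frac{\bar{y}}{\hat{y}}.
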
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 0468a1b..dbc6289 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -18,10 +18,9 @@
package org.apache.spark.ml.regression
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index c6b3327..c72ef29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -17,62 +17,40 @@
package org.apache.spark.ml.regression
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
-/**
- * :: DeveloperApi ::
- * Params for regression.
- * Currently empty, but may add functionality later.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] trait RegressorParams extends PredictorParams
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
*
* Single-label regression
*
* @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
* @tparam Learner Concrete Estimator type
* @tparam M Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
+@DeveloperApi
private[spark] abstract class Regressor[
FeaturesType,
Learner <: Regressor[FeaturesType, Learner, M],
M <: RegressionModel[FeaturesType, M]]
- extends Predictor[FeaturesType, Learner, M]
- with RegressorParams {
+ extends Predictor[FeaturesType, Learner, M] with PredictorParams {
// TODO: defaultEvaluator (follow-up PR)
}
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
*
* Model produced by a [[Regressor]].
*
* @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
* @tparam M Concrete Model type.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@AlphaComponent
-private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
- extends PredictionModel[FeaturesType, M] with RegressorParams {
-
- /**
- * :: DeveloperApi ::
- *
- * Predict real-valued label for the given features.
- * This internal method is used to implement [[transform()]] and output [[predictionCol]].
- */
- @DeveloperApi
- protected def predict(features: FeaturesType): Double
+@DeveloperApi
+abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M] with PredictorParams {
+ // TODO: defaultEvaluator (follow-up PR)
}
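
Because RegressionModel is public after this change, user code can be written generically over any concrete regression model. A hedged sketch; the helper name is illustrative:

    import org.apache.spark.ml.regression.RegressionModel
    import org.apache.spark.sql.DataFrame

    // Works for any fitted regression model (e.g. a LinearRegressionModel):
    // transform() appends the prediction column configured via predictionCol.
    def scoreAll[F, M <: RegressionModel[F, M]](model: M, df: DataFrame): DataFrame =
      model.transform(df)
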
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
new file mode 100644
index 0000000..816fced
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+ /**
+ * Maximum depth of the tree (>= 0).
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
+ ParamValidators.gtEq(0))
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be >= 2 and >= number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature.", ParamValidators.gtEq(2))
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Should be >= 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Should be >= 1.", ParamValidators.gtEq(1))
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.")
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.",
+ ParamValidators.gtEq(0))
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertParam
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+ " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
+ " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+ " checkpoint directory is set in the SparkContext. Must be >= 1.",
+ ParamValidators.gtEq(1))
+
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+ maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+ /** @group getParam */
+ final def getMaxDepth: Int = $(maxDepth)
+
+ /** @group setParam */
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+ /** @group getParam */
+ final def getMaxBins: Int = $(maxBins)
+
+ /** @group setParam */
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+ /** @group getParam */
+ final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+
+ /** @group setParam */
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+ /** @group getParam */
+ final def getMinInfoGain: Double = $(minInfoGain)
+
+ /** @group expertSetParam */
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+ /** @group expertGetParam */
+ final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
+
+ /** @group expertSetParam */
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+ /** @group expertGetParam */
+ final def getCacheNodeIds: Boolean = $(cacheNodeIds)
+
+ /** @group expertSetParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+ /** @group expertGetParam */
+ final def getCheckpointInterval: Int = $(checkpointInterval)
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity,
+ subsamplingRate: Double): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(oldAlgo)
+ strategy.impurity = oldImpurity
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = subsamplingRate
+ strategy
+ }
+}
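
A hedged sketch of how this bridge is used: from inside a learner that mixes in DecisionTreeParams (the method is private[ml]), the new Params are funneled into a pre-existing mllib Strategy. The argument values here are illustrative:

    import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
    import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}

    // Illustrative call from a regression learner; numClasses = 0 is the
    // old API's convention for regression.
    val strategy = getOldStrategy(
      categoricalFeatures = Map.empty[Int, Int],
      numClasses = 0,
      oldAlgo = OldAlgo.Regression,
      oldImpurity = OldVariance,
      subsamplingRate = 1.0)
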
+
+/**
+ * Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
+ (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
+
+ setDefault(impurity -> "gini")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ /** @group getParam */
+ final def getImpurity: String = $(impurity).toLowerCase
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
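
Since the validator lower-cases its input, the impurity param is effectively case-insensitive. A hedged sketch using the classifier from this patch:

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    val dtc = new DecisionTreeClassifier().setImpurity("Gini")  // accepted
    dtc.getImpurity  // "gini": the getter also lower-cases
    // .setImpurity("absurd") would be rejected by the validator above
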
+
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+ (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
+
+ setDefault(impurity -> "variance")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = set(impurity, value)
+
+ /** @group getParam */
+ final def getImpurity: String = $(impurity).toLowerCase
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object TreeRegressorParams {
+ // These options should be lowercase.
+ final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based ensemble algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+
+ /**
+ * Fraction of the training data used for learning each decision tree, in range (0, 1].
+ * (default = 1.0)
+ * @group param
+ */
+ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree, in range (0, 1].",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ setDefault(subsamplingRate -> 1.0)
+
+ /** @group setParam */
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+ /** @group getParam */
+ final def getSubsamplingRate: Double = $(subsamplingRate)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and seed.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+ }
+}
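
A brief usage sketch for the ensemble-level params, assuming a concrete learner such as RandomForestRegressor; the values are illustrative:

    import org.apache.spark.ml.regression.RandomForestRegressor

    val rf = new RandomForestRegressor()
      .setSubsamplingRate(0.8)  // each tree trains on ~80% of the data
      .setSeed(42L)             // from HasSeed, for reproducible sampling
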
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Random Forest algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ /**
+ * The number of features to consider for splits at each tree node.
+ * Supported options:
+ * - "auto": Choose automatically for task:
+ * If numTrees == 1, set to "all."
+ * If numTrees > 1 (forest), set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * - "all": use all features
+ * - "onethird": use 1/3 of the features
+ * - "sqrt": use sqrt(number of features)
+ * - "log2": use log2(number of features)
+ * (default = "auto")
+ *
+ * These various settings are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
+ * random forests]]
+ *
+ * @group param
+ */
+ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node." +
+ s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+ (value: String) =>
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+
+ setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+
+ /** @group setParam */
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+ /** @group getParam */
+ final def getNumTrees: Int = $(numTrees)
+
+ /** @group setParam */
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+
+ /** @group getParam */
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+private[ml] object RandomForestParams {
+ // These options should be lowercase.
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+}
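
A hedged usage sketch for the forest-specific params; per the doc above, with numTrees > 1 on a classifier the "auto" strategy resolves to "sqrt":

    import org.apache.spark.ml.classification.RandomForestClassifier

    val rfc = new RandomForestClassifier()
      .setNumTrees(50)
      .setFeatureSubsetStrategy("auto")  // "sqrt" for a multi-tree classifier
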
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
+
+ /**
+ * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+ * estimator.
+ * (default = 0.1)
+ * @group param
+ */
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
+ " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ /* TODO: Add this doc when we add this param. SPARK-7132
+ * Threshold for stopping early when runWithValidation is used.
+ * If the error rate on the validation input changes by less than the validationTol,
+ * then learning will stop early (before [[numIterations]]).
+ * This parameter is ignored when run is used.
+ * (default = 1e-5)
+ * @group param
+ */
+ // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
+ // validationTol -> 1e-5
+
+ setDefault(maxIter -> 20, stepSize -> 0.1)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /** @group getParam */
+ final def getStepSize: Double = $(stepSize)
+
+ /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
+ private[ml] def getOldBoostingStrategy(
+ categoricalFeatures: Map[Int, Int],
+ oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+ // NOTE: The old API does not support "seed" so we ignore it.
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+ }
+
+ /** Get old Gradient Boosting Loss type */
+ private[ml] def getOldLossType: OldLoss
+}
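
And a final hedged sketch for the boosting params, again with illustrative values; both defaults are set above:

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor()
      .setMaxIter(100)    // number of boosting iterations, i.e. trees (default 20)
      .setStepSize(0.05)  // shrinkage in (0, 1] (default 0.1)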