Posted to commits@spark.apache.org by me...@apache.org on 2014/11/20 09:49:07 UTC
[2/2] spark git commit: [SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.
1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
1. GradientBoosting -> GradientBoostedTrees
1. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
1. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
1. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
1. Rename the class method `train` to `run` because it hides the static methods with the same name in Java. Deprecated the `DecisionTree.train` class method.
1. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy; we create ensembleStrategy inside boosting (see the usage sketch after this list).
1. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
1. doc updates
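
As a quick orientation to the updated API, here is a minimal sketch based on the diff below. It assumes a SparkContext `sc` is already in scope and reuses the sample data path from the examples; it is illustrative only, not part of this commit.

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
    import org.apache.spark.mllib.util.MLUtils

    // Load training data in LIBSVM format (placeholder path).
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

    // algo and numClasses now live on treeStrategy; boosting-specific settings
    // (numIterations, learningRate, loss) stay on BoostingStrategy.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.maxDepth = 5
    boostingStrategy.treeStrategy.numClassesForClassification = 2

    // The static `train` and the instance method `run` are equivalent entry points.
    val model: GradientBoostedTreesModel = GradientBoostedTrees.train(data, boostingStrategy)
    // val model = new GradientBoostedTrees(boostingStrategy).run(data)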
manishamde jkbradley
Author: Xiangrui Meng <me...@databricks.com>
Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:
7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/15cacc81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/15cacc81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/15cacc81
Branch: refs/heads/master
Commit: 15cacc81240eed8834b4730c5c6dc3238f003465
Parents: e216ffa
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu Nov 20 00:48:59 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Nov 20 00:48:59 2014 -0800
----------------------------------------------------------------------
.../mllib/JavaGradientBoostedTrees.java | 126 ----------
.../mllib/JavaGradientBoostedTreesRunner.java | 126 ++++++++++
.../examples/mllib/DecisionTreeRunner.scala | 18 +-
.../examples/mllib/GradientBoostedTrees.scala | 146 -----------
.../mllib/GradientBoostedTreesRunner.scala | 146 +++++++++++
.../apache/spark/mllib/tree/DecisionTree.scala | 20 +-
.../spark/mllib/tree/GradientBoostedTrees.scala | 192 ++++++++++++++
.../spark/mllib/tree/GradientBoosting.scala | 249 -------------------
.../apache/spark/mllib/tree/RandomForest.scala | 40 ++-
.../tree/configuration/BoostingStrategy.scala | 50 ++--
.../EnsembleCombiningStrategy.scala | 8 +-
.../mllib/tree/configuration/Strategy.scala | 7 +
.../spark/mllib/tree/loss/AbsoluteError.scala | 6 +-
.../apache/spark/mllib/tree/loss/LogLoss.scala | 6 +-
.../org/apache/spark/mllib/tree/loss/Loss.scala | 6 +-
.../spark/mllib/tree/loss/SquaredError.scala | 6 +-
.../mllib/tree/model/DecisionTreeModel.scala | 4 +-
.../tree/model/WeightedEnsembleModel.scala | 158 ------------
.../mllib/tree/model/treeEnsembleModels.scala | 178 +++++++++++++
.../spark/mllib/tree/JavaDecisionTreeSuite.java | 2 +-
.../spark/mllib/tree/EnsembleTestHelper.scala | 30 ++-
.../mllib/tree/GradientBoostedTreesSuite.scala | 117 +++++++++
.../mllib/tree/GradientBoostingSuite.scala | 126 ----------
.../spark/mllib/tree/RandomForestSuite.scala | 14 +-
24 files changed, 863 insertions(+), 918 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
deleted file mode 100644
index 1af2067..0000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import scala.Tuple2;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
-import org.apache.spark.mllib.util.MLUtils;
-
-/**
- * Classification and regression using gradient-boosted decision trees.
- */
-public final class JavaGradientBoostedTrees {
-
- private static void usage() {
- System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
- " <Classification/Regression>");
- System.exit(-1);
- }
-
- public static void main(String[] args) {
- String datapath = "data/mllib/sample_libsvm_data.txt";
- String algo = "Classification";
- if (args.length >= 1) {
- datapath = args[0];
- }
- if (args.length >= 2) {
- algo = args[1];
- }
- if (args.length > 2) {
- usage();
- }
- SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
-
- JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
-
- // Set parameters.
- // Note: All features are treated as continuous.
- BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
- boostingStrategy.setNumIterations(10);
- boostingStrategy.weakLearnerParams().setMaxDepth(5);
-
- if (algo.equals("Classification")) {
- // Compute the number of classes from the data.
- Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
- @Override public Double call(LabeledPoint p) {
- return p.label();
- }
- }).countByValue().size();
- boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
-
- // Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
-
- // Evaluate model on training instances and compute training error
- JavaPairRDD<Double, Double> predictionAndLabel =
- data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
- @Override public Tuple2<Double, Double> call(LabeledPoint p) {
- return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
- }
- });
- Double trainErr =
- 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
- @Override public Boolean call(Tuple2<Double, Double> pl) {
- return !pl._1().equals(pl._2());
- }
- }).count() / data.count();
- System.out.println("Training error: " + trainErr);
- System.out.println("Learned classification tree model:\n" + model);
- } else if (algo.equals("Regression")) {
- // Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
-
- // Evaluate model on training instances and compute training error
- JavaPairRDD<Double, Double> predictionAndLabel =
- data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
- @Override public Tuple2<Double, Double> call(LabeledPoint p) {
- return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
- }
- });
- Double trainMSE =
- predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
- @Override public Double call(Tuple2<Double, Double> pl) {
- Double diff = pl._1() - pl._2();
- return diff * diff;
- }
- }).reduce(new Function2<Double, Double, Double>() {
- @Override public Double call(Double a, Double b) {
- return a + b;
- }
- }) / data.count();
- System.out.println("Training Mean Squared Error: " + trainMSE);
- System.out.println("Learned regression tree model:\n" + model);
- } else {
- usage();
- }
-
- sc.stop();
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
new file mode 100644
index 0000000..4a5ac40
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTreesRunner {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
+ " <Classification/Regression>");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
+
+ // Train a GradientBoostedTrees model for classification.
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+ @Override public Boolean call(Tuple2<Double, Double> pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoostedTrees model for regression.
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+ @Override public Double call(Tuple2<Double, Double> pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 63f02cf..98f9d16 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -352,21 +352,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
-
- /**
- * Calculates the mean squared error for regression.
- */
private[mllib] def meanSquaredError(
- tree: WeightedEnsembleModel,
+ model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
data.map { y =>
- val err = tree.predict(y.features) - y.label
+ val err = model.predict(y.features) - y.label
err * err
}.mean()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
deleted file mode 100644
index 9b6db01..0000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib
-
-import scopt.OptionParser
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
-import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
-import org.apache.spark.util.Utils
-
-/**
- * An example runner for Gradient Boosting using decision trees as weak learners. Run with
- * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
- * }}}
- * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
- *
- * Note: This script treats all features as real-valued (not categorical).
- * To include categorical features, modify categoricalFeaturesInfo.
- */
-object GradientBoostedTrees {
-
- case class Params(
- input: String = null,
- testInput: String = "",
- dataFormat: String = "libsvm",
- algo: String = "Classification",
- maxDepth: Int = 5,
- numIterations: Int = 10,
- fracTest: Double = 0.2) extends AbstractParams[Params]
-
- def main(args: Array[String]) {
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("GradientBoostedTrees") {
- head("GradientBoostedTrees: an example decision tree app.")
- opt[String]("algo")
- .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
- .action((x, c) => c.copy(algo = x))
- opt[Int]("maxDepth")
- .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
- .action((x, c) => c.copy(maxDepth = x))
- opt[Int]("numIterations")
- .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
- .action((x, c) => c.copy(numIterations = x))
- opt[Double]("fracTest")
- .text(s"fraction of data to hold out for testing. If given option testInput, " +
- s"this option is ignored. default: ${defaultParams.fracTest}")
- .action((x, c) => c.copy(fracTest = x))
- opt[String]("testInput")
- .text(s"input path to test dataset. If given, option fracTest is ignored." +
- s" default: ${defaultParams.testInput}")
- .action((x, c) => c.copy(testInput = x))
- opt[String]("<dataFormat>")
- .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
- .action((x, c) => c.copy(dataFormat = x))
- arg[String]("<input>")
- .text("input path to labeled examples")
- .required()
- .action((x, c) => c.copy(input = x))
- checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest > 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
- } else {
- success
- }
- }
- }
-
- parser.parse(args, defaultParams).map { params =>
- run(params)
- }.getOrElse {
- sys.exit(1)
- }
- }
-
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
- val sc = new SparkContext(conf)
-
- println(s"GradientBoostedTrees with parameters:\n$params")
-
- // Load training and test data and cache it.
- val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
- params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
-
- val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
- boostingStrategy.numClassesForClassification = numClasses
- boostingStrategy.numIterations = params.numIterations
- boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
-
- val randomSeed = Utils.random.nextInt()
- if (params.algo == "Classification") {
- val startTime = System.nanoTime()
- val model = GradientBoosting.trainClassifier(training, boostingStrategy)
- val elapsedTime = (System.nanoTime() - startTime) / 1e9
- println(s"Training time: $elapsedTime seconds")
- if (model.totalNumNodes < 30) {
- println(model.toDebugString) // Print full model.
- } else {
- println(model) // Print model summary.
- }
- val trainAccuracy =
- new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
- .precision
- println(s"Train accuracy = $trainAccuracy")
- val testAccuracy =
- new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
- println(s"Test accuracy = $testAccuracy")
- } else if (params.algo == "Regression") {
- val startTime = System.nanoTime()
- val model = GradientBoosting.trainRegressor(training, boostingStrategy)
- val elapsedTime = (System.nanoTime() - startTime) / 1e9
- println(s"Training time: $elapsedTime seconds")
- if (model.totalNumNodes < 30) {
- println(model.toDebugString) // Print full model.
- } else {
- println(model) // Print model summary.
- }
- val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
- println(s"Train mean squared error = $trainMSE")
- val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
- println(s"Test mean squared error = $testMSE")
- }
-
- sc.stop()
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
new file mode 100644
index 0000000..1def8b4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoostedTrees
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTreesRunner {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTreesRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.treeStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.treeStrategy.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 78acc17..3d91867 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+ def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
- val rfModel = rf.train(input)
- rfModel.weakHypotheses(0)
+ val rfModel = rf.run(input)
+ rfModel.trees(0)
}
+ /**
+ * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+ * methods with the same name in Java.
+ */
+ @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+ def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
}
object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
new file mode 100644
index 0000000..cb4ddfc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ * - This currently can be run with several loss functions. However, only SquaredError is
+ * fully supported. Specifically, the loss function should be used to compute the gradient
+ * (to re-label training instances on each iteration) and to weight weak hypotheses.
+ * Currently, gradients are computed correctly for the available loss functions,
+ * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ * Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+@Experimental
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+ extends Serializable with Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+ case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+ */
+ def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ run(input.rdd)
+ }
+}
+
+
+object GradientBoostedTrees extends Logging {
+
+ /**
+ * Method to train a gradient boosting model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ new GradientBoostedTrees(boostingStrategy).run(input)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
+ */
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ train(input.rdd, boostingStrategy)
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param boostingStrategy boosting parameters
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ private def boost(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ boostingStrategy.assertValid()
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ treeStrategy.algo = Regression
+ treeStrategy.impurity = Variance
+ treeStrategy.assertValid()
+
+ // Cache input
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+ var data = input
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = 1.0
+ val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
+ logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ // Pseudo-residual for the second iteration
+ data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
+ point.features))
+
+ var m = 1
+ while (m < numIterations) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val model = new DecisionTree(treeStrategy).run(data)
+ timer.stop(s"building tree $m")
+ // Create partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+ // Note: A model of type regression is used since we require raw prediction
+ val partialModel = new GradientBoostedTreesModel(
+ Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
+ logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ // Update data with pseudo-residuals
+ data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+ point.features))
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
deleted file mode 100644
index f729344..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree
-
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: Experimental ::
- * A class that implements Stochastic Gradient Boosting
- * for regression and binary classification problems.
- *
- * The implementation is based upon:
- * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
- *
- * Notes:
- * - This currently can be run with several loss functions. However, only SquaredError is
- * fully supported. Specifically, the loss function should be used to compute the gradient
- * (to re-label training instances on each iteration) and to weight weak hypotheses.
- * Currently, gradients are computed correctly for the available loss functions,
- * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
- * Running with those losses will likely behave reasonably, but lacks the same guarantees.
- *
- * @param boostingStrategy Parameters for the gradient boosting algorithm
- */
-@Experimental
-class GradientBoosting (
- private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
-
- boostingStrategy.weakLearnerParams.algo = Regression
- boostingStrategy.weakLearnerParams.impurity = impurity.Variance
-
- // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- boostingStrategy.weakLearnerParams.numClassesForClassification =
- boostingStrategy.numClassesForClassification
-
- boostingStrategy.assertValid()
-
- /**
- * Method to train a gradient boosting model
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
- algo match {
- case Regression => GradientBoosting.boost(input, boostingStrategy)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoosting.boost(remappedInput, boostingStrategy)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
- }
-
-}
-
-
-object GradientBoosting extends Logging {
-
- /**
- * Method to train a gradient boosting model.
- *
- * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
- * is recommended to clearly specify regression.
- * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
- * is recommended to clearly specify regression.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def train(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting classification model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainClassifier(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
- require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting regression model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainRegressor(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
- require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
- */
- def train(
- input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- train(input.rdd, boostingStrategy)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
- */
- def trainClassifier(
- input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- trainClassifier(input.rdd, boostingStrategy)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
- */
- def trainRegressor(
- input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- trainRegressor(input.rdd, boostingStrategy)
- }
-
- /**
- * Internal method for performing regression using trees as base learners.
- * @param input training dataset
- * @param boostingStrategy boosting parameters
- * @return
- */
- private def boost(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-
- val timer = new TimeTracker()
- timer.start("total")
- timer.start("init")
-
- // Initialize gradient boosting parameters
- val numIterations = boostingStrategy.numIterations
- val baseLearners = new Array[DecisionTreeModel](numIterations)
- val baseLearnerWeights = new Array[Double](numIterations)
- val loss = boostingStrategy.loss
- val learningRate = boostingStrategy.learningRate
- val strategy = boostingStrategy.weakLearnerParams
-
- // Cache input
- if (input.getStorageLevel == StorageLevel.NONE) {
- input.persist(StorageLevel.MEMORY_AND_DISK)
- }
-
- timer.stop("init")
-
- logDebug("##########")
- logDebug("Building tree 0")
- logDebug("##########")
- var data = input
-
- // Initialize tree
- timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(strategy).train(data)
- baseLearners(0) = firstTreeModel
- baseLearnerWeights(0) = 1.0
- val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
- Sum)
- logDebug("error of gbt = " + loss.computeError(startingModel, input))
- // Note: A model of type regression is used since we require raw prediction
- timer.stop("building tree 0")
-
- // psuedo-residual for second iteration
- data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
- point.features))
-
- var m = 1
- while (m < numIterations) {
- timer.start(s"building tree $m")
- logDebug("###################################################")
- logDebug("Gradient boosting tree iteration " + m)
- logDebug("###################################################")
- val model = new DecisionTree(strategy).train(data)
- timer.stop(s"building tree $m")
- // Create partial model
- baseLearners(m) = model
- // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
- // Technically, the weight should be optimized for the particular loss.
- // However, the behavior should be reasonable, though not optimal.
- baseLearnerWeights(m) = learningRate
- // Note: A model of type regression is used since we require raw prediction
- val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
- baseLearnerWeights.slice(0, m + 1), Regression, Sum)
- logDebug("error of gbt = " + loss.computeError(partialModel, input))
- // Update data with pseudo-residuals
- data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
- point.features))
- m += 1
- }
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
-
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 9683916..ca0b6ee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -17,18 +17,18 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.JavaConverters._
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
+ TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -79,9 +79,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+ def run(input: RDD[LabeledPoint]): RandomForestModel = {
val timer = new TimeTracker()
@@ -212,8 +212,7 @@ private class RandomForest (
}
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- val treeWeights = Array.fill[Double](numTrees)(1.0)
- new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
+ new RandomForestModel(strategy.algo, trees)
}
}
@@ -234,18 +233,18 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -322,18 +321,18 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -387,7 +386,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
3 * totalBins
}
}
-
}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index abbda04..e703adb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
/**
* :: Experimental ::
- * Stores all the configuration options for the boosting algorithms
- * @param algo Learning goal. Supported:
- * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]].
+ *
+ * @param treeStrategy Parameters for the tree algorithm. We support regression and binary
+ * classification for boosting. Impurity setting will be ignored.
+ * @param loss Loss function used for minimization during gradient boosting.
* @param numIterations Number of iterations of boosting. In other words, the number of
* weak hypotheses used in the final model.
- * @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1]
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * This setting overrides any setting in [[weakLearnerParams]].
- * Default value is 2 (binary classification).
- * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
- * supported.
*/
@Experimental
case class BoostingStrategy(
// Required boosting parameters
- @BeanProperty var algo: Algo,
- @BeanProperty var numIterations: Int,
+ @BeanProperty var treeStrategy: Strategy,
@BeanProperty var loss: Loss,
// Optional boosting parameters
- @BeanProperty var learningRate: Double = 0.1,
- @BeanProperty var numClassesForClassification: Int = 2,
- @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
-
- // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- weakLearnerParams.numClassesForClassification = numClassesForClassification
-
- /**
- * Sets Algorithm using a String.
- */
- def setAlgo(algo: String): Unit = algo match {
- case "Classification" => setAlgo(Classification)
- case "Regression" => setAlgo(Regression)
- }
+ @BeanProperty var numIterations: Int = 100,
+ @BeanProperty var learningRate: Double = 0.1) extends Serializable {
/**
* Check validity of parameters.
* Throws exception if invalid.
*/
private[tree] def assertValid(): Unit = {
- algo match {
+ treeStrategy.algo match {
case Classification =>
- require(numClassesForClassification == 2)
+ require(treeStrategy.numClassesForClassification == 2,
+ "Only binary classification is supported for boosting.")
case Regression =>
// nothing
case _ =>
throw new IllegalArgumentException(
- s"BoostingStrategy given invalid algo parameter: $algo." +
+ s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." +
s" Valid settings are: Classification, Regression.")
}
require(learningRate > 0 && learningRate <= 1,
@@ -94,14 +76,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy("Regression")
+ val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
algo match {
case "Classification" =>
- new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ treeStrategy.numClassesForClassification = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
- new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
- weakLearnerParams = treeStrategy)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
}
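A sketch of how the simplified BoostingStrategy might be configured after this change, following the pattern implied by defaultParams above. The training RDD `trainingData` is a placeholder; treeStrategy, numIterations, learningRate, and the `run` method are the pieces introduced or renamed in this commit.

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    // algo and numClasses now live on the inner tree strategy,
    // not on BoostingStrategy itself.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 50
    boostingStrategy.learningRate = 0.1
    boostingStrategy.treeStrategy.maxDepth = 5
    boostingStrategy.treeStrategy.numClassesForClassification = 2

    // The class method is now `run` (the old `train` hid Java static methods).
    val model = new GradientBoostedTrees(boostingStrategy).run(trainingData)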
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
index 82889dc..b5bf732 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -17,14 +17,10 @@
package org.apache.spark.mllib.tree.configuration
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: Experimental ::
* Enum to select ensemble combining strategy for base learners
*/
-@DeveloperApi
-object EnsembleCombiningStrategy extends Enumeration {
+private[tree] object EnsembleCombiningStrategy extends Enumeration {
type EnsembleCombiningStrategy = Value
- val Sum, Average = Value
+ val Average, Sum, Vote = Value
}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b5b1f82..d75f384 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -157,6 +157,13 @@ class Strategy (
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+
+ /** Returns a shallow copy of this instance. */
+ def copy: Strategy = {
+ new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+ }
}
@Experimental
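The new Strategy.copy exists so that a caller's configuration object is never mutated by the boosting code. A tiny sketch of the intended pattern (the names here are illustrative, not taken from the patch):

    // Take a defensive copy before any internal adjustments, so the
    // caller's Strategy instance is left untouched.
    val internalStrategy = boostingStrategy.treeStrategy.copy
    internalStrategy.maxDepth = 3   // affects only the copy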
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d111ffe..e828866 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -42,7 +42,7 @@ object AbsoluteError extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
}
@@ -55,7 +55,7 @@ object AbsoluteError extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
val sumOfAbsolutes = data.map { y =>
val err = model.predict(y.features) - y.label
math.abs(err)
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 6f3d434..8b8adb4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -42,7 +42,7 @@ object LogLoss extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
val prediction = model.predict(point.features)
1.0 / (1.0 + math.exp(-prediction)) - point.label
@@ -56,7 +56,7 @@ object LogLoss extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
wrongPredictions / data.count
}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 5580866..4bca903 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -36,7 +36,7 @@ trait Loss extends Serializable {
* @return Loss gradient.
*/
def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double
/**
@@ -47,6 +47,6 @@ trait Loss extends Serializable {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return
*/
- def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
}
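The built-in losses (AbsoluteError, SquaredError, LogLoss) are singleton objects implementing this trait, and they plug directly into BoostingStrategy. A short sketch, assuming the BoostingStrategy API shown earlier in this diff:

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.tree.loss.AbsoluteError

    // SquaredError and AbsoluteError target regression; LogLoss is the
    // classification default set by defaultParams.
    val regressionBoosting = BoostingStrategy.defaultParams("Regression")
    regressionBoosting.loss = AbsoluteError   // switch from the SquaredError default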
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 4349fef..cfe395b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -43,7 +43,7 @@ object SquaredError extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
model.predict(point.features) - point.label
}
@@ -56,7 +56,7 @@ object SquaredError extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = model.predict(y.features) - y.label
err * err
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ac4d02e..a576096 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,11 @@
package org.apache.spark.mllib.tree.model
-import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.Vector
/**
* :: Experimental ::
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
deleted file mode 100644
index 7b052d9..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
-import org.apache.spark.rdd.RDD
-
-import scala.collection.mutable
-
-@Experimental
-class WeightedEnsembleModel(
- val weakHypotheses: Array[DecisionTreeModel],
- val weakHypothesisWeights: Array[Double],
- val algo: Algo,
- val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
-
- require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" +
- s". Number of weakHypotheses = $weakHypotheses")
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- private def predictRaw(features: Vector): Double = {
- val treePredictions = weakHypotheses.map(learner => learner.predict(features))
- if (numWeakHypotheses == 1){
- treePredictions(0)
- } else {
- var prediction = treePredictions(0)
- var index = 1
- while (index < numWeakHypotheses) {
- prediction += weakHypothesisWeights(index) * treePredictions(index)
- index += 1
- }
- prediction
- }
- }
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- private def predictBySumming(features: Vector): Double = {
- algo match {
- case Regression => predictRaw(features)
- case Classification => {
- // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
- if (predictRaw(features) > 0 ) 1.0 else 0.0
- }
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown algo parameter: $algo.")
- }
- }
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- private def predictByAveraging(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- weakHypotheses.foreach { learner =>
- val prediction = learner.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
- }
- }
-
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- def predict(features: Vector): Double = {
- combiningStrategy match {
- case Sum => predictBySumming(features)
- case Average => predictByAveraging(features)
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = {
- algo match {
- case Classification =>
- s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n"
- case Regression =>
- s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n"
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown algo parameter: $algo.")
- }
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
- /**
- * Get number of trees in forest.
- */
- def numWeakHypotheses: Int = weakHypotheses.size
-
- // TODO: Remove these helpers methods once class is generalized to support any base learning
- // algorithms.
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
new file mode 100644
index 0000000..2299711
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Represents a random forest model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ */
+@Experimental
+class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ combiningStrategy = if (algo == Classification) Vote else Average) {
+
+ require(trees.forall(_.algo == algo))
+}
+
+/**
+ * :: Experimental ::
+ * Represents a gradient boosted trees model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ * @param treeWeights tree ensemble weights
+ */
+@Experimental
+class GradientBoostedTreesModel(
+ override val algo: Algo,
+ override val trees: Array[DecisionTreeModel],
+ override val treeWeights: Array[Double])
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+
+ require(trees.size == treeWeights.size)
+}
+
+/**
+ * Represents a tree ensemble model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ * @param treeWeights tree ensemble weights
+ * @param combiningStrategy strategy for combining the predictions, not used for regression.
+ */
+private[tree] sealed class TreeEnsembleModel(
+ protected val algo: Algo,
+ protected val trees: Array[DecisionTreeModel],
+ protected val treeWeights: Array[Double],
+ protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+ require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
+
+ private val sumWeights = math.max(treeWeights.sum, 1e-15)
+
+ /**
+ * Predicts for a single data point using the weighted sum of ensemble predictions.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ private def predictBySumming(features: Vector): Double = {
+ val treePredictions = trees.map(_.predict(features))
+ blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
+ }
+
+ /**
+ * Classifies a single data point based on (weighted) majority votes.
+ */
+ private def predictByVoting(features: Vector): Double = {
+ val votes = mutable.Map.empty[Int, Double]
+ trees.view.zip(treeWeights).foreach { case (tree, weight) =>
+ val prediction = tree.predict(features).toInt
+ votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
+ }
+ votes.maxBy(_._2)._1
+ }
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ def predict(features: Vector): Double = {
+ (algo, combiningStrategy) match {
+ case (Regression, Sum) =>
+ predictBySumming(features)
+ case (Regression, Average) =>
+ predictBySumming(features) / sumWeights
+ case (Classification, Sum) => // binary classification
+ val prediction = predictBySumming(features)
+ // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+ if (prediction > 0.0) 1.0 else 0.0
+ case (Classification, Vote) =>
+ predictByVoting(features)
+ case _ =>
+ throw new IllegalArgumentException(
+ "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
+ s"($algo, $combiningStrategy).")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]].
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+ predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
+ }
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = {
+ algo match {
+ case Classification =>
+ s"TreeEnsembleModel classifier with $numTrees trees\n"
+ case Regression =>
+ s"TreeEnsembleModel regressor with $numTrees trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"TreeEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /**
+ * Get number of trees in forest.
+ */
+ def numTrees: Int = trees.size
+
+ /**
+ * Get total number of nodes, summed over all trees in the forest.
+ */
+ def totalNumNodes: Int = trees.map(_.numNodes).sum
+}
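To make the combining strategies above concrete, here is a small standalone Scala illustration of the weighted-vote rule that predictByVoting implements. This is not code from the commit, just a plain-Scala restatement of the logic.

    // Each tree's class prediction is weighted by the tree's weight;
    // the class with the largest total weight wins.
    def weightedVote(predictions: Seq[Int], weights: Seq[Double]): Int = {
      predictions.zip(weights)
        .groupBy(_._1)                  // group (prediction, weight) pairs by class
        .mapValues(_.map(_._2).sum)     // total weight per class
        .maxBy(_._2)._1                 // class with the largest total weight
    }

    weightedVote(Seq(1, 0, 1), Seq(0.5, 2.0, 0.4))   // => 0: one heavy vote beats two light ones

For a RandomForestModel all tree weights are 1.0 (see Array.fill(trees.size)(1.0) above), so this reduces to a simple majority vote; GradientBoostedTreesModel instead sums weighted raw predictions via predictBySumming.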
http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 2c281a1..9925aae 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -74,7 +74,7 @@ public class JavaDecisionTreeSuite implements Serializable {
maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy);
- DecisionTreeModel model = learner.train(rdd.rdd());
+ DecisionTreeModel model = learner.run(rdd.rdd());
int numCorrect = validatePrediction(arr, model);
Assert.assertTrue(numCorrect == rdd.count());