Posted to commits@spark.apache.org by me...@apache.org on 2014/11/20 09:49:06 UTC

[1/2] spark git commit: [SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc

Repository: spark
Updated Branches:
  refs/heads/master e216ffaea -> 15cacc812


http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index effb7b8..8972c22 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.util.StatCounter
 
 import scala.collection.mutable
@@ -48,7 +48,7 @@ object EnsembleTestHelper {
   }
 
   def validateClassifier(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       input: Seq[LabeledPoint],
       requiredAccuracy: Double) {
     val predictions = input.map(x => model.predict(x.features))
@@ -60,17 +60,27 @@ object EnsembleTestHelper {
       s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
   }
 
+  /**
+   * Validates a tree ensemble model for regression.
+   */
   def validateRegressor(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       input: Seq[LabeledPoint],
-      requiredMSE: Double) {
+      required: Double,
+      metricName: String = "mse") {
     val predictions = input.map(x => model.predict(x.features))
-    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      val err = prediction - expected.label
-      err * err
-    }.sum
-    val mse = squaredError / input.length
-    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+    val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
+      prediction - label
+    }
+    val metric = metricName match {
+      case "mse" =>
+        errors.map(err => err * err).sum / errors.size
+      case "mae" =>
+        errors.map(math.abs).sum / errors.size
+    }
+
+    assert(metric <= required,
+      s"validateRegressor calculated $metricName $metric but required $required.")
   }
 
   def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
new file mode 100644
index 0000000..f3f8eff
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[GradientBoostedTrees]].
+ */
+class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
+
+  test("Regression with continuous features: SquaredError") {
+    GradientBoostedTreesSuite.testCombinations.foreach {
+      case (numIterations, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+        val rdd = sc.parallelize(arr, 2)
+
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+        assert(gbt.trees.size === numIterations)
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        // Make sure trees are the same.
+        assert(gbt.trees.head.toString == dt.toString)
+    }
+  }
+
+  test("Regression with continuous features: Absolute Error") {
+    GradientBoostedTreesSuite.testCombinations.foreach {
+      case (numIterations, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+        val rdd = sc.parallelize(arr, 2)
+
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate)
+
+        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+        assert(gbt.trees.size === numIterations)
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.85, "mae")
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        // Make sure trees are the same.
+        assert(gbt.trees.head.toString == dt.toString)
+    }
+  }
+
+  test("Binary classification with continuous features: Log Loss") {
+    GradientBoostedTreesSuite.testCombinations.foreach {
+      case (numIterations, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+        val rdd = sc.parallelize(arr, 2)
+
+        val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
+          numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
+          subsamplingRate = subsamplingRate)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate)
+
+        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+        assert(gbt.trees.size === numIterations)
+        EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val ensembleStrategy = treeStrategy.copy
+        ensembleStrategy.algo = Regression
+        ensembleStrategy.impurity = Variance
+        val dt = DecisionTree.train(remappedInput, ensembleStrategy)
+
+        // Make sure trees are the same.
+        assert(gbt.trees.head.toString == dt.toString)
+    }
+  }
+
+}
+
+object GradientBoostedTreesSuite {
+
+  // Test combinations of (numIterations, learningRate, subsamplingRate)
+  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
deleted file mode 100644
index 84de401..0000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-
-/**
- * Test suite for [[GradientBoosting]].
- */
-class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
-
-  test("Regression with continuous features: SquaredError") {
-    GradientBoostingSuite.testCombinations.foreach {
-      case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr)
-        val categoricalFeaturesInfo = Map.empty[Int, Int]
-
-        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
-          subsamplingRate = subsamplingRate)
-
-        val dt = DecisionTree.train(remappedInput, treeStrategy)
-
-        val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
-          learningRate, 1, treeStrategy)
-
-        val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numIterations)
-        val gbtTree = gbt.weakHypotheses(0)
-
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
-
-        // Make sure trees are the same.
-        assert(gbtTree.toString == dt.toString)
-    }
-  }
-
-  test("Regression with continuous features: Absolute Error") {
-    GradientBoostingSuite.testCombinations.foreach {
-      case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr)
-        val categoricalFeaturesInfo = Map.empty[Int, Int]
-
-        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
-          subsamplingRate = subsamplingRate)
-
-        val dt = DecisionTree.train(remappedInput, treeStrategy)
-
-        val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
-          learningRate, numClassesForClassification = 2, treeStrategy)
-
-        val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numIterations)
-        val gbtTree = gbt.weakHypotheses(0)
-
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
-
-        // Make sure trees are the same.
-        assert(gbtTree.toString == dt.toString)
-    }
-  }
-
-  test("Binary classification with continuous features: Log Loss") {
-    GradientBoostingSuite.testCombinations.foreach {
-      case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr)
-        val categoricalFeaturesInfo = Map.empty[Int, Int]
-
-        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
-          subsamplingRate = subsamplingRate)
-
-        val dt = DecisionTree.train(remappedInput, treeStrategy)
-
-        val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
-          learningRate, numClassesForClassification = 2, treeStrategy)
-
-        val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numIterations)
-        val gbtTree = gbt.weakHypotheses(0)
-
-        EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
-
-        // Make sure trees are the same.
-        assert(gbtTree.toString == dt.toString)
-    }
-  }
-
-}
-
-object GradientBoostingSuite {
-
-  // Combinations for estimators, learning rates and subsamplingRate
-  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 2734e08..90a8c2d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -41,8 +41,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
 
     val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.weakHypotheses.size === 1)
-    val rfTree = rf.weakHypotheses(0)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
 
@@ -65,7 +65,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
     " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
     binaryClassificationTestWithContinuousFeatures(strategy)
   }
 
@@ -76,8 +77,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
 
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.weakHypotheses.size === 1)
-    val rfTree = rf.weakHypotheses(0)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
 
@@ -175,7 +176,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
   test("Binary classification with continuous features and node Id cache: subsampling features") {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
     binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
   }
 




[2/2] spark git commit: [SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc

Posted by me...@apache.org.
[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc

There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is tied to trees. This was partly due to the delay of SPARK-1856, but for the 1.2 release we should make the APIs consistent.

1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel, with members renamed accordingly.
2. GradientBoosting -> GradientBoostedTrees (a usage sketch of the new API follows this list).
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy.
4. Slightly refactor TreeEnsembleModel (voting now takes weights into consideration).
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`.
6. Rename the class method `train` to `run` because it hides the static methods with the same name in Java. Deprecate the `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy; we create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError.
9. Doc updates.
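
For reference, a minimal sketch of training with the renamed API, assembled from the diffs below; the SparkContext `sc` and the data path are placeholders, not part of this change:

  import org.apache.spark.mllib.tree.GradientBoostedTrees
  import org.apache.spark.mllib.tree.configuration.BoostingStrategy
  import org.apache.spark.mllib.util.MLUtils

  val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

  // Algo and numClasses now live on treeStrategy, not on BoostingStrategy itself.
  val boostingStrategy = BoostingStrategy.defaultParams("Classification")
  boostingStrategy.numIterations = 10
  boostingStrategy.treeStrategy.numClassesForClassification = 2
  boostingStrategy.treeStrategy.maxDepth = 5

  // `train` replaces the removed `trainClassifier`/`trainRegressor`;
  // the corresponding instance method is now `run`.
  val model = GradientBoostedTrees.train(data, boostingStrategy)
  println(s"Trained an ensemble of ${model.trees.size} trees")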

manishamde jkbradley

Author: Xiangrui Meng <me...@databricks.com>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/15cacc81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/15cacc81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/15cacc81

Branch: refs/heads/master
Commit: 15cacc81240eed8834b4730c5c6dc3238f003465
Parents: e216ffa
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu Nov 20 00:48:59 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Nov 20 00:48:59 2014 -0800

----------------------------------------------------------------------
 .../mllib/JavaGradientBoostedTrees.java         | 126 ----------
 .../mllib/JavaGradientBoostedTreesRunner.java   | 126 ++++++++++
 .../examples/mllib/DecisionTreeRunner.scala     |  18 +-
 .../examples/mllib/GradientBoostedTrees.scala   | 146 -----------
 .../mllib/GradientBoostedTreesRunner.scala      | 146 +++++++++++
 .../apache/spark/mllib/tree/DecisionTree.scala  |  20 +-
 .../spark/mllib/tree/GradientBoostedTrees.scala | 192 ++++++++++++++
 .../spark/mllib/tree/GradientBoosting.scala     | 249 -------------------
 .../apache/spark/mllib/tree/RandomForest.scala  |  40 ++-
 .../tree/configuration/BoostingStrategy.scala   |  50 ++--
 .../EnsembleCombiningStrategy.scala             |   8 +-
 .../mllib/tree/configuration/Strategy.scala     |   7 +
 .../spark/mllib/tree/loss/AbsoluteError.scala   |   6 +-
 .../apache/spark/mllib/tree/loss/LogLoss.scala  |   6 +-
 .../org/apache/spark/mllib/tree/loss/Loss.scala |   6 +-
 .../spark/mllib/tree/loss/SquaredError.scala    |   6 +-
 .../mllib/tree/model/DecisionTreeModel.scala    |   4 +-
 .../tree/model/WeightedEnsembleModel.scala      | 158 ------------
 .../mllib/tree/model/treeEnsembleModels.scala   | 178 +++++++++++++
 .../spark/mllib/tree/JavaDecisionTreeSuite.java |   2 +-
 .../spark/mllib/tree/EnsembleTestHelper.scala   |  30 ++-
 .../mllib/tree/GradientBoostedTreesSuite.scala  | 117 +++++++++
 .../mllib/tree/GradientBoostingSuite.scala      | 126 ----------
 .../spark/mllib/tree/RandomForestSuite.scala    |  14 +-
 24 files changed, 863 insertions(+), 918 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
deleted file mode 100644
index 1af2067..0000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import scala.Tuple2;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
-import org.apache.spark.mllib.util.MLUtils;
-
-/**
- * Classification and regression using gradient-boosted decision trees.
- */
-public final class JavaGradientBoostedTrees {
-
-  private static void usage() {
-    System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
-        " <Classification/Regression>");
-    System.exit(-1);
-  }
-
-  public static void main(String[] args) {
-    String datapath = "data/mllib/sample_libsvm_data.txt";
-    String algo = "Classification";
-    if (args.length >= 1) {
-      datapath = args[0];
-    }
-    if (args.length >= 2) {
-      algo = args[1];
-    }
-    if (args.length > 2) {
-      usage();
-    }
-    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
-    JavaSparkContext sc = new JavaSparkContext(sparkConf);
-
-    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
-
-    // Set parameters.
-    //  Note: All features are treated as continuous.
-    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
-    boostingStrategy.setNumIterations(10);
-    boostingStrategy.weakLearnerParams().setMaxDepth(5);
-
-    if (algo.equals("Classification")) {
-      // Compute the number of classes from the data.
-      Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
-        @Override public Double call(LabeledPoint p) {
-          return p.label();
-        }
-      }).countByValue().size();
-      boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
-
-      // Train a GradientBoosting model for classification.
-      final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
-
-      // Evaluate model on training instances and compute training error
-      JavaPairRDD<Double, Double> predictionAndLabel =
-          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
-              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
-            }
-          });
-      Double trainErr =
-          1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
-            @Override public Boolean call(Tuple2<Double, Double> pl) {
-              return !pl._1().equals(pl._2());
-            }
-          }).count() / data.count();
-      System.out.println("Training error: " + trainErr);
-      System.out.println("Learned classification tree model:\n" + model);
-    } else if (algo.equals("Regression")) {
-      // Train a GradientBoosting model for classification.
-      final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
-
-      // Evaluate model on training instances and compute training error
-      JavaPairRDD<Double, Double> predictionAndLabel =
-          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
-              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
-            }
-          });
-      Double trainMSE =
-          predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
-            @Override public Double call(Tuple2<Double, Double> pl) {
-              Double diff = pl._1() - pl._2();
-              return diff * diff;
-            }
-          }).reduce(new Function2<Double, Double, Double>() {
-            @Override public Double call(Double a, Double b) {
-              return a + b;
-            }
-          }) / data.count();
-      System.out.println("Training Mean Squared Error: " + trainMSE);
-      System.out.println("Learned regression tree model:\n" + model);
-    } else {
-      usage();
-    }
-
-    sc.stop();
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
new file mode 100644
index 0000000..4a5ac40
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTreesRunner {
+
+  private static void usage() {
+    System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
+        " <Classification/Regression>");
+    System.exit(-1);
+  }
+
+  public static void main(String[] args) {
+    String datapath = "data/mllib/sample_libsvm_data.txt";
+    String algo = "Classification";
+    if (args.length >= 1) {
+      datapath = args[0];
+    }
+    if (args.length >= 2) {
+      algo = args[1];
+    }
+    if (args.length > 2) {
+      usage();
+    }
+    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
+    JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+    // Set parameters.
+    //  Note: All features are treated as continuous.
+    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+    boostingStrategy.setNumIterations(10);
+    boostingStrategy.treeStrategy().setMaxDepth(5);
+
+    if (algo.equals("Classification")) {
+      // Compute the number of classes from the data.
+      Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+        @Override public Double call(LabeledPoint p) {
+          return p.label();
+        }
+      }).countByValue().size();
+      boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
+
+      // Train a GradientBoosting model for classification.
+      final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+      // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+            }
+          });
+      Double trainErr =
+          1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+            @Override public Boolean call(Tuple2<Double, Double> pl) {
+              return !pl._1().equals(pl._2());
+            }
+          }).count() / data.count();
+      System.out.println("Training error: " + trainErr);
+      System.out.println("Learned classification tree model:\n" + model);
+    } else if (algo.equals("Regression")) {
+      // Train a GradientBoosting model for regression.
+      final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+      // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+            }
+          });
+      Double trainMSE =
+          predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+            @Override public Double call(Tuple2<Double, Double> pl) {
+              Double diff = pl._1() - pl._2();
+              return diff * diff;
+            }
+          }).reduce(new Function2<Double, Double, Double>() {
+            @Override public Double call(Double a, Double b) {
+              return a + b;
+            }
+          }) / data.count();
+      System.out.println("Training Mean Squared Error: " + trainMSE);
+      System.out.println("Learned regression tree model:\n" + model);
+    } else {
+      usage();
+    }
+
+    sc.stop();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 63f02cf..98f9d16 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -352,21 +352,11 @@ object DecisionTreeRunner {
   /**
    * Calculates the mean squared error for regression.
    */
-  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = tree.predict(y.features) - y.label
-      err * err
-    }.mean()
-  }
-
-  /**
-   * Calculates the mean squared error for regression.
-   */
   private[mllib] def meanSquaredError(
-      tree: WeightedEnsembleModel,
+      model: { def predict(features: Vector): Double },
       data: RDD[LabeledPoint]): Double = {
     data.map { y =>
-      val err = tree.predict(y.features) - y.label
+      val err = model.predict(y.features) - y.label
       err * err
     }.mean()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
deleted file mode 100644
index 9b6db01..0000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib
-
-import scopt.OptionParser
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
-import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
-import org.apache.spark.util.Utils
-
-/**
- * An example runner for Gradient Boosting using decision trees as weak learners. Run with
- * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
- * }}}
- * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
- *
- * Note: This script treats all features as real-valued (not categorical).
- *       To include categorical features, modify categoricalFeaturesInfo.
- */
-object GradientBoostedTrees {
-
-  case class Params(
-      input: String = null,
-      testInput: String = "",
-      dataFormat: String = "libsvm",
-      algo: String = "Classification",
-      maxDepth: Int = 5,
-      numIterations: Int = 10,
-      fracTest: Double = 0.2) extends AbstractParams[Params]
-
-  def main(args: Array[String]) {
-    val defaultParams = Params()
-
-    val parser = new OptionParser[Params]("GradientBoostedTrees") {
-      head("GradientBoostedTrees: an example decision tree app.")
-      opt[String]("algo")
-        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
-        .action((x, c) => c.copy(algo = x))
-      opt[Int]("maxDepth")
-        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
-        .action((x, c) => c.copy(maxDepth = x))
-      opt[Int]("numIterations")
-        .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
-        .action((x, c) => c.copy(numIterations = x))
-      opt[Double]("fracTest")
-        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
-          s"this option is ignored. default: ${defaultParams.fracTest}")
-        .action((x, c) => c.copy(fracTest = x))
-      opt[String]("testInput")
-        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
-          s" default: ${defaultParams.testInput}")
-        .action((x, c) => c.copy(testInput = x))
-      opt[String]("<dataFormat>")
-        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
-        .action((x, c) => c.copy(dataFormat = x))
-      arg[String]("<input>")
-        .text("input path to labeled examples")
-        .required()
-        .action((x, c) => c.copy(input = x))
-      checkConfig { params =>
-        if (params.fracTest < 0 || params.fracTest > 1) {
-          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
-        } else {
-          success
-        }
-      }
-    }
-
-    parser.parse(args, defaultParams).map { params =>
-      run(params)
-    }.getOrElse {
-      sys.exit(1)
-    }
-  }
-
-  def run(params: Params) {
-
-    val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
-    val sc = new SparkContext(conf)
-
-    println(s"GradientBoostedTrees with parameters:\n$params")
-
-    // Load training and test data and cache it.
-    val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
-      params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
-
-    val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
-    boostingStrategy.numClassesForClassification = numClasses
-    boostingStrategy.numIterations = params.numIterations
-    boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
-
-    val randomSeed = Utils.random.nextInt()
-    if (params.algo == "Classification") {
-      val startTime = System.nanoTime()
-      val model = GradientBoosting.trainClassifier(training, boostingStrategy)
-      val elapsedTime = (System.nanoTime() - startTime) / 1e9
-      println(s"Training time: $elapsedTime seconds")
-      if (model.totalNumNodes < 30) {
-        println(model.toDebugString) // Print full model.
-      } else {
-        println(model) // Print model summary.
-      }
-      val trainAccuracy =
-        new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
-          .precision
-      println(s"Train accuracy = $trainAccuracy")
-      val testAccuracy =
-        new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
-      println(s"Test accuracy = $testAccuracy")
-    } else if (params.algo == "Regression") {
-      val startTime = System.nanoTime()
-      val model = GradientBoosting.trainRegressor(training, boostingStrategy)
-      val elapsedTime = (System.nanoTime() - startTime) / 1e9
-      println(s"Training time: $elapsedTime seconds")
-      if (model.totalNumNodes < 30) {
-        println(model.toDebugString) // Print full model.
-      } else {
-        println(model) // Print model summary.
-      }
-      val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
-      println(s"Train mean squared error = $trainMSE")
-      val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
-      println(s"Test mean squared error = $testMSE")
-    }
-
-    sc.stop()
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
new file mode 100644
index 0000000..1def8b4
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoostedTrees
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ *       To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTreesRunner {
+
+  case class Params(
+      input: String = null,
+      testInput: String = "",
+      dataFormat: String = "libsvm",
+      algo: String = "Classification",
+      maxDepth: Int = 5,
+      numIterations: Int = 10,
+      fracTest: Double = 0.2) extends AbstractParams[Params]
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("GradientBoostedTrees") {
+      head("GradientBoostedTrees: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("numIterations")
+        .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+        .action((x, c) => c.copy(numIterations = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+          s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+          s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
+      opt[String]("<dataFormat>")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
+      arg[String]("<input>")
+        .text("input path to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+
+    val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
+    val sc = new SparkContext(conf)
+
+    println(s"GradientBoostedTreesRunner with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+      params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+    val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+    boostingStrategy.treeStrategy.numClassesForClassification = numClasses
+    boostingStrategy.numIterations = params.numIterations
+    boostingStrategy.treeStrategy.maxDepth = params.maxDepth
+
+    val randomSeed = Utils.random.nextInt()
+    if (params.algo == "Classification") {
+      val startTime = System.nanoTime()
+      val model = GradientBoostedTrees.train(training, boostingStrategy)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+      if (model.totalNumNodes < 30) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
+      val trainAccuracy =
+        new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+          .precision
+      println(s"Train accuracy = $trainAccuracy")
+      val testAccuracy =
+        new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+      println(s"Test accuracy = $testAccuracy")
+    } else if (params.algo == "Regression") {
+      val startTime = System.nanoTime()
+      val model = GradientBoostedTrees.train(training, boostingStrategy)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+      if (model.totalNumNodes < 30) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
+      val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+      println(s"Train mean squared error = $trainMSE")
+      val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+      println(s"Test mean squared error = $testMSE")
+    }
+
+    sc.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 78acc17..3d91867 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @return DecisionTreeModel that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
     // Note: random seed will not be used since numTrees = 1.
     val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
-    val rfModel = rf.train(input)
-    rfModel.weakHypotheses(0)
+    val rfModel = rf.run(input)
+    rfModel.trees(0)
   }
 
+  /**
+   * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+   * methods with the same name in Java.
+   */
+  @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+  def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
 }
 
 object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
    * @return DecisionTreeModel that can be used for prediction
   */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
       maxDepth: Int,
       numClassesForClassification: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
new file mode 100644
index 0000000..cb4ddfc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
+ *
+ * The implementation is based upon:
+ *   J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
+ *
+ * Notes:
+ *  - This currently can be run with several loss functions.  However, only SquaredError is
+ *    fully supported.  Specifically, the loss function should be used to compute the gradient
+ *    (to re-label training instances on each iteration) and to weight weak hypotheses.
+ *    Currently, gradients are computed correctly for the available loss functions,
+ *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+@Experimental
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+  extends Serializable with Logging {
+
+  /**
+   * Method to train a gradient boosting model
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+      case Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+    }
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+   */
+  def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+    run(input.rdd)
+  }
+}
+
+
+object GradientBoostedTrees extends Logging {
+
+  /**
+   * Method to train a gradient boosting model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param boostingStrategy Configuration options for the boosting algorithm.
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+    new GradientBoostedTrees(boostingStrategy).run(input)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
+   */
+  def train(
+      input: JavaRDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+    train(input.rdd, boostingStrategy)
+  }
+
+  /**
+   * Internal method for performing regression using trees as base learners.
+   * @param input training dataset
+   * @param boostingStrategy boosting parameters
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  private def boost(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+
+    val timer = new TimeTracker()
+    timer.start("total")
+    timer.start("init")
+
+    boostingStrategy.assertValid()
+
+    // Initialize gradient boosting parameters
+    val numIterations = boostingStrategy.numIterations
+    val baseLearners = new Array[DecisionTreeModel](numIterations)
+    val baseLearnerWeights = new Array[Double](numIterations)
+    val loss = boostingStrategy.loss
+    val learningRate = boostingStrategy.learningRate
+    // Prepare strategy for individual trees, which use regression with variance impurity.
+    val treeStrategy = boostingStrategy.treeStrategy.copy
+    treeStrategy.algo = Regression
+    treeStrategy.impurity = Variance
+    treeStrategy.assertValid()
+
+    // Cache input
+    if (input.getStorageLevel == StorageLevel.NONE) {
+      input.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    timer.stop("init")
+
+    logDebug("##########")
+    logDebug("Building tree 0")
+    logDebug("##########")
+    var data = input
+
+    // Initialize tree
+    timer.start("building tree 0")
+    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    baseLearners(0) = firstTreeModel
+    baseLearnerWeights(0) = 1.0
+    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
+    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    // Note: A model of type regression is used since we require raw prediction
+    timer.stop("building tree 0")
+
+    // Pseudo-residuals for the second iteration: the negative gradient, as in
+    // the update inside the loop below.
+    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
+      point.features))
+
+    var m = 1
+    while (m < numIterations) {
+      timer.start(s"building tree $m")
+      logDebug("###################################################")
+      logDebug("Gradient boosting tree iteration " + m)
+      logDebug("###################################################")
+      val model = new DecisionTree(treeStrategy).run(data)
+      timer.stop(s"building tree $m")
+      // Create partial model
+      baseLearners(m) = model
+      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+      //       Technically, the weight should be optimized for the particular loss.
+      //       However, the behavior should be reasonable, though not optimal.
+      baseLearnerWeights(m) = learningRate
+      // Note: A model of type regression is used since we require raw prediction
+      val partialModel = new GradientBoostedTreesModel(
+        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
+      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+      // Update data with pseudo-residuals
+      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+        point.features))
+      m += 1
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
+
+    new GradientBoostedTreesModel(
+      boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+  }
+}
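
A minimal end-to-end sketch of the new entry points, for readers of this patch (not part of the commit; it assumes a SparkContext `sc`, and the file path is illustrative):

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.util.MLUtils

    // Load a binary-classification dataset (hypothetical path).
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

    // Start from the provided defaults, then tune the boosting parameters.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10

    val model = GradientBoostedTrees.train(data, boostingStrategy)

    // Training 0-1 error; predictions come back as 0.0/1.0 after the
    // internal {-1, +1} remapping.
    val trainErr =
      data.filter(p => model.predict(p.features) != p.label).count.toDouble / data.count
    println(s"Training error: $trainErr")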

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
deleted file mode 100644
index f729344..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree
-
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: Experimental ::
- * A class that implements Stochastic Gradient Boosting
- * for regression and binary classification problems.
- *
- * The implementation is based upon:
- *   J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
- *
- * Notes:
- *  - This currently can be run with several loss functions.  However, only SquaredError is
- *    fully supported.  Specifically, the loss function should be used to compute the gradient
- *    (to re-label training instances on each iteration) and to weight weak hypotheses.
- *    Currently, gradients are computed correctly for the available loss functions,
- *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
- *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
- *
- * @param boostingStrategy Parameters for the gradient boosting algorithm
- */
-@Experimental
-class GradientBoosting (
-    private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
-
-  boostingStrategy.weakLearnerParams.algo = Regression
-  boostingStrategy.weakLearnerParams.impurity = impurity.Variance
-
-  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
-  boostingStrategy.weakLearnerParams.numClassesForClassification =
-    boostingStrategy.numClassesForClassification
-
-  boostingStrategy.assertValid()
-
-  /**
-   * Method to train a gradient boosting model
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
-    algo match {
-      case Regression => GradientBoosting.boost(input, boostingStrategy)
-      case Classification =>
-        // Map labels to -1, +1 so binary classification can be treated as regression.
-        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        GradientBoosting.boost(remappedInput, boostingStrategy)
-      case _ =>
-        throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
-    }
-  }
-
-}
-
-
-object GradientBoosting extends Logging {
-
-  /**
-   * Method to train a gradient boosting model.
-   *
-   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
-   *       is recommended to clearly specify regression.
-   *       Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
-   *       is recommended to clearly specify regression.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Method to train a gradient boosting classification model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def trainClassifier(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
-    require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Method to train a gradient boosting regression model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def trainRegressor(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
-    require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
-   */
-  def train(
-    input: JavaRDD[LabeledPoint],
-    boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    train(input.rdd, boostingStrategy)
-  }
-
-  /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
-   */
-  def trainClassifier(
-      input: JavaRDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    trainClassifier(input.rdd, boostingStrategy)
-  }
-
-  /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
-   */
-  def trainRegressor(
-      input: JavaRDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    trainRegressor(input.rdd, boostingStrategy)
-  }
-
-  /**
-   * Internal method for performing regression using trees as base learners.
-   * @param input training dataset
-   * @param boostingStrategy boosting parameters
-   * @return
-   */
-  private def boost(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-
-    val timer = new TimeTracker()
-    timer.start("total")
-    timer.start("init")
-
-    // Initialize gradient boosting parameters
-    val numIterations = boostingStrategy.numIterations
-    val baseLearners = new Array[DecisionTreeModel](numIterations)
-    val baseLearnerWeights = new Array[Double](numIterations)
-    val loss = boostingStrategy.loss
-    val learningRate = boostingStrategy.learningRate
-    val strategy = boostingStrategy.weakLearnerParams
-
-    // Cache input
-    if (input.getStorageLevel == StorageLevel.NONE) {
-      input.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-
-    timer.stop("init")
-
-    logDebug("##########")
-    logDebug("Building tree 0")
-    logDebug("##########")
-    var data = input
-
-    // Initialize tree
-    timer.start("building tree 0")
-    val firstTreeModel = new DecisionTree(strategy).train(data)
-    baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = 1.0
-    val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
-      Sum)
-    logDebug("error of gbt = " + loss.computeError(startingModel, input))
-    // Note: A model of type regression is used since we require raw prediction
-    timer.stop("building tree 0")
-
-    // psuedo-residual for second iteration
-    data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
-      point.features))
-
-    var m = 1
-    while (m < numIterations) {
-      timer.start(s"building tree $m")
-      logDebug("###################################################")
-      logDebug("Gradient boosting tree iteration " + m)
-      logDebug("###################################################")
-      val model = new DecisionTree(strategy).train(data)
-      timer.stop(s"building tree $m")
-      // Create partial model
-      baseLearners(m) = model
-      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
-      //       Technically, the weight should be optimized for the particular loss.
-      //       However, the behavior should be reasonable, though not optimal.
-      baseLearnerWeights(m) = learningRate
-      // Note: A model of type regression is used since we require raw prediction
-      val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
-        baseLearnerWeights.slice(0, m + 1), Regression, Sum)
-      logDebug("error of gbt = " + loss.computeError(partialModel, input))
-      // Update data with pseudo-residuals
-      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
-        point.features))
-      m += 1
-    }
-
-    timer.stop("total")
-
-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
-
-    new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
-
-  }
-
-}
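
For callers of the class removed above, the mechanical migration is small (a sketch, not taken from the patch):

    // Before this commit:
    //   val model: WeightedEnsembleModel =
    //     GradientBoosting.trainClassifier(input, boostingStrategy)
    // After this commit:
    //   val model: GradientBoostedTreesModel =
    //     GradientBoostedTrees.train(input, boostingStrategy)
    // BoostingStrategy itself changes shape as well: weak-learner settings
    // now live on boostingStrategy.treeStrategy (see the BoostingStrategy
    // diff below).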

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 9683916..ca0b6ee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.mllib.tree
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
+  TimeTracker, TreePoint}
 import org.apache.spark.mllib.tree.impurity.Impurities
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -79,9 +79,9 @@ private class RandomForest (
   /**
    * Method to train a decision tree model over an RDD
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+  def run(input: RDD[LabeledPoint]): RandomForestModel = {
 
     val timer = new TimeTracker()
 
@@ -212,8 +212,7 @@ private class RandomForest (
     }
 
     val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
-    val treeWeights = Array.fill[Double](numTrees)(1.0)
-    new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
+    new RandomForestModel(strategy.algo, trees)
   }
 
 }
@@ -234,18 +233,18 @@ object RandomForest extends Serializable with Logging {
    *                                if numTrees > 1 (forest) set to "sqrt" for classification and
    *                                  to "onethird" for regression.
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
       strategy: Strategy,
       numTrees: Int,
       featureSubsetStrategy: String,
-      seed: Int): WeightedEnsembleModel = {
+      seed: Int): RandomForestModel = {
     require(strategy.algo == Classification,
       s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
     val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
-    rf.train(input)
+    rf.run(input)
   }
 
   /**
@@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxBins maximum number of bins used for splitting features
    *                 (suggested value: 100)
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
@@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
     val impurityType = Impurities.fromString(impurity)
     val strategy = new Strategy(Classification, impurityType, maxDepth,
       numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int): WeightedEnsembleModel = {
+      seed: Int): RandomForestModel = {
     trainClassifier(input.rdd, numClassesForClassification,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -322,18 +321,18 @@ object RandomForest extends Serializable with Logging {
    *                                if numTrees > 1 (forest) set to "sqrt" for classification and
    *                                  to "onethird" for regression.
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
       strategy: Strategy,
       numTrees: Int,
       featureSubsetStrategy: String,
-      seed: Int): WeightedEnsembleModel = {
+      seed: Int): RandomForestModel = {
     require(strategy.algo == Regression,
       s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
     val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
-    rf.train(input)
+    rf.run(input)
   }
 
   /**
@@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxBins maximum number of bins used for splitting features
    *                 (suggested value: 100)
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
@@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
     val impurityType = Impurities.fromString(impurity)
     val strategy = new Strategy(Regression, impurityType, maxDepth,
       0, maxBins, Sort, categoricalFeaturesInfo)
@@ -387,7 +386,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int): WeightedEnsembleModel = {
+      seed: Int): RandomForestModel = {
     trainRegressor(input.rdd,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
       3 * totalBins
     }
   }
-
 }
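
A short usage sketch of the renamed API and its new return type (not part of the commit; `sc` and the file path are assumptions, parameter values are illustrative):

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
    val model: RandomForestModel = RandomForest.trainClassifier(
      data,
      2,                    // numClassesForClassification
      Map.empty[Int, Int],  // categoricalFeaturesInfo: none
      10,                   // numTrees
      "auto",               // featureSubsetStrategy: sqrt for classification
      "gini",               // impurity
      4,                    // maxDepth
      32)                   // maxBins
    println(s"Learned ${model.numTrees} trees, ${model.totalNumNodes} nodes in total.")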

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index abbda04..e703adb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
 
 /**
  * :: Experimental ::
- * Stores all the configuration options for the boosting algorithms
- * @param algo  Learning goal.  Supported:
- *              [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]].
+ *
+ * @param treeStrategy Parameters for the tree algorithm. We support regression and binary
+ *                     classification for boosting. Impurity setting will be ignored.
+ * @param loss Loss function used for minimization during gradient boosting.
  * @param numIterations Number of iterations of boosting.  In other words, the number of
  *                      weak hypotheses used in the final model.
- * @param loss Loss function used for minimization during gradient boosting.
  * @param learningRate Learning rate for shrinking the contribution of each estimator. The
 *                     learning rate should be in the interval (0, 1].
- * @param numClassesForClassification Number of classes for classification.
- *                                    (Ignored for regression.)
- *                                    This setting overrides any setting in [[weakLearnerParams]].
- *                                    Default value is 2 (binary classification).
- * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
- *                          supported.
  */
 @Experimental
 case class BoostingStrategy(
     // Required boosting parameters
-    @BeanProperty var algo: Algo,
-    @BeanProperty var numIterations: Int,
+    @BeanProperty var treeStrategy: Strategy,
     @BeanProperty var loss: Loss,
     // Optional boosting parameters
-    @BeanProperty var learningRate: Double = 0.1,
-    @BeanProperty var numClassesForClassification: Int = 2,
-    @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
-
-  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
-  weakLearnerParams.numClassesForClassification = numClassesForClassification
-
-  /**
-   * Sets Algorithm using a String.
-   */
-  def setAlgo(algo: String): Unit = algo match {
-    case "Classification" => setAlgo(Classification)
-    case "Regression" => setAlgo(Regression)
-  }
+    @BeanProperty var numIterations: Int = 100,
+    @BeanProperty var learningRate: Double = 0.1) extends Serializable {
 
   /**
    * Check validity of parameters.
    * Throws exception if invalid.
    */
   private[tree] def assertValid(): Unit = {
-    algo match {
+    treeStrategy.algo match {
       case Classification =>
-        require(numClassesForClassification == 2)
+        require(treeStrategy.numClassesForClassification == 2,
+          "Only binary classification is supported for boosting.")
       case Regression =>
         // nothing
       case _ =>
         throw new IllegalArgumentException(
-          s"BoostingStrategy given invalid algo parameter: $algo." +
+          s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." +
             s"  Valid settings are: Classification, Regression.")
     }
     require(learningRate > 0 && learningRate <= 1,
@@ -94,14 +76,14 @@ object BoostingStrategy {
    * @return Configuration for boosting algorithm
    */
   def defaultParams(algo: String): BoostingStrategy = {
-    val treeStrategy = Strategy.defaultStrategy("Regression")
+    val treeStrategy = Strategy.defaultStrategy(algo)
     treeStrategy.maxDepth = 3
     algo match {
       case "Classification" =>
-        new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+        treeStrategy.numClassesForClassification = 2
+        new BoostingStrategy(treeStrategy, LogLoss)
       case "Regression" =>
-        new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
-          weakLearnerParams = treeStrategy)
+        new BoostingStrategy(treeStrategy, SquaredError)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
     }
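
With the weak-learner parameters folded into treeStrategy, all configuration now flows through a single object. A sketch of the new shape (values illustrative):

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 50   // number of trees in the ensemble
    boostingStrategy.learningRate = 0.05  // must lie in (0, 1]
    // Tree-level settings now live on the nested Strategy:
    boostingStrategy.treeStrategy.maxDepth = 5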

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
index 82889dc..b5bf732 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -17,14 +17,10 @@
 
 package org.apache.spark.mllib.tree.configuration
 
-import org.apache.spark.annotation.DeveloperApi
-
 /**
- * :: Experimental ::
  * Enum to select ensemble combining strategy for base learners
  */
-@DeveloperApi
-object EnsembleCombiningStrategy extends Enumeration {
+private[tree] object EnsembleCombiningStrategy extends Enumeration {
   type EnsembleCombiningStrategy = Value
-  val Sum, Average = Value
+  val Average, Sum, Vote = Value
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b5b1f82..d75f384 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -157,6 +157,13 @@ class Strategy (
     require(maxMemoryInMB <= 10240,
       s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
   }
+
+  /** Returns a shallow copy of this instance. */
+  def copy: Strategy = {
+    new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+      quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+  }
 }
 
 @Experimental
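
The new copy method exists so that boosting can override algo and impurity on its own instance; a tiny sketch of the behavior (standalone, not part of the patch):

    import org.apache.spark.mllib.tree.configuration.Strategy

    val base = Strategy.defaultStrategy("Classification")
    val forBoosting = base.copy
    forBoosting.maxDepth = base.maxDepth + 1
    // `base` is untouched; only the copy was mutated. This is what
    // GradientBoostedTrees.boost relies on when it forces Regression/Variance.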

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d111ffe..e828866 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -42,7 +42,7 @@ object AbsoluteError extends Loss {
    * @return Loss gradient
    */
   override def gradient(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       point: LabeledPoint): Double = {
     if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
   }
@@ -55,7 +55,7 @@ object AbsoluteError extends Loss {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return  Mean absolute error of the model on the data.
    */
-  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
     val sumOfAbsolutes = data.map { y =>
       val err = model.predict(y.features) - y.label
       math.abs(err)
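
As a reading aid (not part of the patch), the sign convention of the gradient above, spelled out in a standalone snippet where `f` stands in for the ensemble prediction F(x):

    // d|y - f| / df = -sign(y - f), i.e. +1 when the model overshoots the
    // label and -1 when it undershoots, matching the branch in
    // AbsoluteError.gradient.
    def absGradient(label: Double, f: Double): Double =
      if (label - f < 0) 1.0 else -1.0

    assert(absGradient(label = 1.0, f = 2.0) == 1.0)   // overshoot: positive gradient
    assert(absGradient(label = 1.0, f = 0.0) == -1.0)  // undershoot: negative gradient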

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 6f3d434..8b8adb4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -42,7 +42,7 @@ object LogLoss extends Loss {
    * @return Loss gradient
    */
   override def gradient(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       point: LabeledPoint): Double = {
     val prediction = model.predict(point.features)
     1.0 / (1.0 + math.exp(-prediction)) - point.label
@@ -56,7 +56,7 @@ object LogLoss extends Loss {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return  Mean classification error (fraction of misclassified instances).
    */
-  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
     val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
     wrongPredictions / data.count
   }
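
Two reading notes (observations, not changes made by the patch). First, the gradient formula is the textbook one for labels in {0, 1}: with p = sigmoid(F), the log loss -y*log(p) - (1-y)*log(1-p) has derivative p - y with respect to F; the GBT driver, however, remaps classification labels to {-1, +1}, which is part of why the class doc marks LogLoss as not fully supported. A standalone check of the {0, 1} derivation:

    def sigmoid(f: Double): Double = 1.0 / (1.0 + math.exp(-f))
    def logLossGradient(label: Double, f: Double): Double = sigmoid(f) - label

    // A confidently wrong prediction yields a large-magnitude gradient...
    assert(math.abs(logLossGradient(label = 1.0, f = -5.0)) > 0.99)
    // ...while a confidently correct one is near zero.
    assert(math.abs(logLossGradient(label = 1.0, f = 5.0)) < 0.01)

Second, computeError above divides two Longs, so the returned error truncates toward zero; a toDouble on the numerator would give the intended fraction of misclassified points.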

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 5580866..4bca903 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -36,7 +36,7 @@ trait Loss extends Serializable {
    * @return Loss gradient.
    */
   def gradient(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       point: LabeledPoint): Double
 
   /**
@@ -47,6 +47,6 @@ trait Loss extends Serializable {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return  Measure of the model error on the data.
    */
-  def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 4349fef..cfe395b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -43,7 +43,7 @@ object SquaredError extends Loss {
    * @return Loss gradient
    */
   override def gradient(
-    model: WeightedEnsembleModel,
+    model: TreeEnsembleModel,
     point: LabeledPoint): Double = {
     model.predict(point.features) - point.label
   }
@@ -56,7 +56,7 @@ object SquaredError extends Loss {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return  Mean squared error of the model on the data.
    */
-  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
     data.map { y =>
       val err = model.predict(y.features) - y.label
       err * err
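
One line of calculus makes the boosting loop concrete: taking L(y, F) = (F - y)^2 / 2 (the convention implied by `gradient` above; `computeError` measures (F - y)^2, which only differs by a constant factor), the gradient dL/dF is F - y, so the pseudo-residual -dL/dF = y - F is the plain residual, and each new tree is fit to what the current ensemble still gets wrong. A standalone sketch:

    def squaredErrorGradient(label: Double, f: Double): Double = f - label

    val pseudoResidual = -squaredErrorGradient(label = 3.0, f = 2.5)
    assert(pseudoResidual == 0.5)  // equals label - prediction, the ordinary residual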

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ac4d02e..a576096 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.mllib.tree.model
 
-import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.Vector
 
 /**
  * :: Experimental ::

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
deleted file mode 100644
index 7b052d9..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
-import org.apache.spark.rdd.RDD
-
-import scala.collection.mutable
-
-@Experimental
-class WeightedEnsembleModel(
-    val weakHypotheses: Array[DecisionTreeModel],
-    val weakHypothesisWeights: Array[Double],
-    val algo: Algo,
-    val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
-
-  require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" +
-    s". Number of weakHypotheses = $weakHypotheses")
-
-  /**
-   * Predict values for a single data point using the model trained.
-   *
-   * @param features array representing a single data point
-   * @return predicted category from the trained model
-   */
-  private def predictRaw(features: Vector): Double = {
-    val treePredictions = weakHypotheses.map(learner => learner.predict(features))
-    if (numWeakHypotheses == 1){
-      treePredictions(0)
-    } else {
-      var prediction = treePredictions(0)
-      var index = 1
-      while (index < numWeakHypotheses) {
-        prediction += weakHypothesisWeights(index) * treePredictions(index)
-        index += 1
-      }
-      prediction
-    }
-  }
-
-  /**
-   * Predict values for a single data point using the model trained.
-   *
-   * @param features array representing a single data point
-   * @return predicted category from the trained model
-   */
-  private def predictBySumming(features: Vector): Double = {
-    algo match {
-      case Regression => predictRaw(features)
-      case Classification => {
-        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
-        if (predictRaw(features) > 0 ) 1.0 else 0.0
-      }
-      case _ => throw new IllegalArgumentException(
-        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
-    }
-  }
-
-  /**
-   * Predict values for a single data point.
-   *
-   * @param features array representing a single data point
-   * @return Double prediction from the trained model
-   */
-  private def predictByAveraging(features: Vector): Double = {
-    algo match {
-      case Classification =>
-        val predictionToCount = new mutable.HashMap[Int, Int]()
-        weakHypotheses.foreach { learner =>
-          val prediction = learner.predict(features).toInt
-          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
-        }
-        predictionToCount.maxBy(_._2)._1
-      case Regression =>
-        weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
-    }
-  }
-
-
-  /**
-   * Predict values for a single data point using the model trained.
-   *
-   * @param features array representing a single data point
-   * @return predicted category from the trained model
-   */
-  def predict(features: Vector): Double = {
-    combiningStrategy match {
-      case Sum => predictBySumming(features)
-      case Average => predictByAveraging(features)
-      case _ => throw new IllegalArgumentException(
-        s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
-    }
-  }
-
-  /**
-   * Predict values for the given data set.
-   *
-   * @param features RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
-   */
-  def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
-
-  /**
-   * Print a summary of the model.
-   */
-  override def toString: String = {
-    algo match {
-      case Classification =>
-        s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n"
-      case Regression =>
-        s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n"
-      case _ => throw new IllegalArgumentException(
-        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
-    }
-  }
-
-  /**
-   * Print the full model to a string.
-   */
-  def toDebugString: String = {
-    val header = toString + "\n"
-    header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
-      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
-    }.fold("")(_ + _)
-  }
-
-  /**
-   * Get number of trees in forest.
-   */
-  def numWeakHypotheses: Int = weakHypotheses.size
-
-  // TODO: Remove these helpers methods once class is generalized to support any base learning
-  // algorithms.
-
-  /**
-   * Get total number of nodes, summed over all trees in the forest.
-   */
-  def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
new file mode 100644
index 0000000..2299711
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Represents a random forest model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees trees in the ensemble
+ */
+@Experimental
+class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
+  extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+    combiningStrategy = if (algo == Classification) Vote else Average) {
+
+  require(trees.forall(_.algo == algo))
+}
+
+/**
+ * :: Experimental ::
+ * Represents a gradient boosted trees model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees trees in the ensemble
+ * @param treeWeights weights of the trees in the ensemble
+ */
+@Experimental
+class GradientBoostedTreesModel(
+    override val algo: Algo,
+    override val trees: Array[DecisionTreeModel],
+    override val treeWeights: Array[Double])
+  extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+
+  require(trees.size == treeWeights.size)
+}
+
+/**
+ * Represents a tree ensemble model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees trees in the ensemble
+ * @param treeWeights weights of the trees in the ensemble
+ * @param combiningStrategy strategy for combining the predictions of the trees
+ */
+private[tree] sealed class TreeEnsembleModel(
+    protected val algo: Algo,
+    protected val trees: Array[DecisionTreeModel],
+    protected val treeWeights: Array[Double],
+    protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+  require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
+
+  private val sumWeights = math.max(treeWeights.sum, 1e-15)
+
+  /**
+   * Predicts for a single data point using the weighted sum of ensemble predictions.
+   *
+   * @param features vector representing a single data point
+   * @return weighted sum of the predictions of the individual trees
+   */
+  private def predictBySumming(features: Vector): Double = {
+    val treePredictions = trees.map(_.predict(features))
+    blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
+  }
+
+  /**
+   * Classifies a single data point based on (weighted) majority votes.
+   */
+  private def predictByVoting(features: Vector): Double = {
+    val votes = mutable.Map.empty[Int, Double]
+    trees.view.zip(treeWeights).foreach { case (tree, weight) =>
+      val prediction = tree.predict(features).toInt
+      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
+    }
+    votes.maxBy(_._2)._1
+  }
+
+  /**
+   * Predicts the value for a single data point using the trained model.
+   *
+   * @param features vector representing a single data point
+   * @return predicted class label (classification) or real value (regression)
+   */
+  def predict(features: Vector): Double = {
+    (algo, combiningStrategy) match {
+      case (Regression, Sum) =>
+        predictBySumming(features)
+      case (Regression, Average) =>
+        predictBySumming(features) / sumWeights
+      case (Classification, Sum) => // binary classification
+        val prediction = predictBySumming(features)
+        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+        if (prediction > 0.0) 1.0 else 0.0
+      case (Classification, Vote) =>
+        predictByVoting(features)
+      case _ =>
+        throw new IllegalArgumentException(
+          "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
+            s"($algo, $combiningStrategy).")
+    }
+  }
+
+  /**
+   * Predict values for the given data set.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+  /**
+   * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]].
+   */
+  def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+    predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
+  }
+
+  /**
+   * Print a summary of the model.
+   */
+  override def toString: String = {
+    algo match {
+      case Classification =>
+        s"TreeEnsembleModel classifier with $numTrees trees\n"
+      case Regression =>
+        s"TreeEnsembleModel regressor with $numTrees trees\n"
+      case _ => throw new IllegalArgumentException(
+        s"TreeEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+  /**
+   * Print the full model to a string.
+   */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+    }.fold("")(_ + _)
+  }
+
+  /**
+   * Get number of trees in the ensemble.
+   */
+  def numTrees: Int = trees.size
+
+  /**
+   * Get total number of nodes, summed over all trees in the ensemble.
+   */
+  def totalNumNodes: Int = trees.map(_.numNodes).sum
+}
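
To make the new Vote path concrete, here is the weighted-vote rule from predictByVoting as a standalone snippet (names illustrative, not part of the patch):

    import scala.collection.mutable

    def weightedVote(predictions: Seq[Int], weights: Seq[Double]): Int = {
      val votes = mutable.Map.empty[Int, Double]
      predictions.zip(weights).foreach { case (p, w) =>
        votes(p) = votes.getOrElse(p, 0.0) + w
      }
      votes.maxBy(_._2)._1
    }

    // Three equally weighted trees: label 1 wins two votes to one.
    assert(weightedVote(Seq(1, 0, 1), Seq(1.0, 1.0, 1.0)) == 1)

RandomForestModel picks Vote for classification and Average for regression in its constructor above, while GradientBoostedTreesModel always combines by Sum.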

http://git-wip-us.apache.org/repos/asf/spark/blob/15cacc81/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 2c281a1..9925aae 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -74,7 +74,7 @@ public class JavaDecisionTreeSuite implements Serializable {
         maxBins, categoricalFeaturesInfo);
 
     DecisionTree learner = new DecisionTree(strategy);
-    DecisionTreeModel model = learner.train(rdd.rdd());
+    DecisionTreeModel model = learner.run(rdd.rdd());
 
     int numCorrect = validatePrediction(arr, model);
     Assert.assertTrue(numCorrect == rdd.count());

