Posted to commits@spark.apache.org by me...@apache.org on 2014/11/05 19:33:19 UTC

git commit: [SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java

Repository: spark
Updated Branches:
  refs/heads/master 5f13759d3 -> 5b3b6f6f5


[SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java

### Summary

* Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types.
* Added Scala and Java examples for GradientBoostedTrees
* Small cleanups and fixes

### Details

GradientBoosting bug fixes (“bug” = bad default options)
* Force boostingStrategy.weakLearnerParams.algo = Regression
* Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance
* Only persist the input data if it is not yet persisted (persisting the same RDD twice causes an error); see the sketch below
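
A minimal Scala sketch of the new persist guard (taken from the GradientBoosting.boost changes in the diff below):

    import org.apache.spark.storage.StorageLevel

    // Persist only if the caller has not already persisted the input RDD;
    // calling persist() twice on the same RDD throws an exception.
    if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
    }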

BoostingStrategy
* numEstimators: renamed to numIterations
* removed subsamplingRate (duplicated by Strategy)
* removed categoricalFeaturesInfo, since it belongs with the weak learner params (boosting itself can be oblivious to feature types)
* Changed algo to var (not val) and added BeanProperty, with overload taking String argument
* Added assertValid() method
* Updated defaultParams() method (it now takes the algo as a String) and eliminated defaultWeakLearnerParams() since that belongs in Strategy; see the usage sketch below
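
With these changes, a default BoostingStrategy can be built from a String and tuned with simple setters. A minimal Scala sketch, mirroring the new GradientBoostedTrees example in the diff below:

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10  // renamed from numEstimators
    boostingStrategy.weakLearnerParams.maxDepth = 5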

Strategy (for DecisionTree)
* Changed algo to var (not val) and added BeanProperty, with overload taking String argument
* Added setCategoricalFeaturesInfo overload taking a Java Map (see the sketch after this list).
* Cleaned up assertValid
* Changed vals to defs since parameters can now be changed.
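
These setters make Strategy configurable from Java using simple types. A minimal Scala sketch exercising the new overloads (method names as added in the diff below):

    import org.apache.spark.mllib.tree.configuration.Strategy

    val strategy = Strategy.defaultStrategy("Classification")
    strategy.setMaxDepth(5)  // BeanProperty setter
    // The new overload accepts a Java Map, so Java callers need no Scala Map:
    val catInfo = new java.util.HashMap[java.lang.Integer, java.lang.Integer]()
    catInfo.put(0, 2)  // feature 0 is categorical with 2 values
    strategy.setCategoricalFeaturesInfo(catInfo)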

CC: manishamde mengxr codedeft

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #3094 from jkbradley/gbt-api and squashes the following commits:

7a27e22 [Joseph K. Bradley] scalastyle fix
52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api
e9b8410 [Joseph K. Bradley] Summary of changes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b3b6f6f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b3b6f6f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b3b6f6f

Branch: refs/heads/master
Commit: 5b3b6f6f5f029164d7749366506e142b104c1d43
Parents: 5f13759
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Wed Nov 5 10:33:13 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Nov 5 10:33:13 2014 -0800

----------------------------------------------------------------------
 .../mllib/JavaGradientBoostedTrees.java         | 126 ++++++++++++++
 .../examples/mllib/DecisionTreeRunner.scala     |  64 ++++---
 .../examples/mllib/GradientBoostedTrees.scala   | 146 ++++++++++++++++
 .../spark/mllib/tree/GradientBoosting.scala     | 169 ++++++-------------
 .../tree/configuration/BoostingStrategy.scala   |  78 ++++-----
 .../mllib/tree/configuration/Strategy.scala     |  51 ++++--
 .../mllib/tree/GradientBoostingSuite.scala      |  34 ++--
 7 files changed, 462 insertions(+), 206 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000..1af2067
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+  private static void usage() {
+    System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+        " <Classification/Regression>");
+    System.exit(-1);
+  }
+
+  public static void main(String[] args) {
+    String datapath = "data/mllib/sample_libsvm_data.txt";
+    String algo = "Classification";
+    if (args.length >= 1) {
+      datapath = args[0];
+    }
+    if (args.length >= 2) {
+      algo = args[1];
+    }
+    if (args.length > 2) {
+      usage();
+    }
+    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+    JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+    // Set parameters.
+    //  Note: All features are treated as continuous.
+    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+    boostingStrategy.setNumIterations(10);
+    boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+    if (algo.equals("Classification")) {
+      // Compute the number of classes from the data.
+      Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+        @Override public Double call(LabeledPoint p) {
+          return p.label();
+        }
+      }).countByValue().size();
+      boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+      // Train a GradientBoosting model for classification.
+      final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+      // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+            }
+          });
+      Double trainErr =
+          1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+            @Override public Boolean call(Tuple2<Double, Double> pl) {
+              return !pl._1().equals(pl._2());
+            }
+          }).count() / data.count();
+      System.out.println("Training error: " + trainErr);
+      System.out.println("Learned classification tree model:\n" + model);
+    } else if (algo.equals("Regression")) {
+      // Train a GradientBoosting model for regression.
+      final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+      // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+            }
+          });
+      Double trainMSE =
+          predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+            @Override public Double call(Tuple2<Double, Double> pl) {
+              Double diff = pl._1() - pl._2();
+              return diff * diff;
+            }
+          }).reduce(new Function2<Double, Double, Double>() {
+            @Override public Double call(Double a, Double b) {
+              return a + b;
+            }
+          }) / data.count();
+      System.out.println("Training Mean Squared Error: " + trainMSE);
+      System.out.println("Learned regression tree model:\n" + model);
+    } else {
+      usage();
+    }
+
+    sc.stop();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 49751a3..63f02cf 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -154,20 +154,30 @@ object DecisionTreeRunner {
     }
   }
 
-  def run(params: Params) {
-
-    val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
-    val sc = new SparkContext(conf)
-
-    println(s"DecisionTreeRunner with parameters:\n$params")
-
+  /**
+   * Load training and test data from files.
+   * @param input  Path to input dataset.
+   * @param dataFormat  "libsvm" or "dense"
+   * @param testInput  Path to test dataset.
+   * @param algo  Classification or Regression
+   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
+   * @return  (training dataset, test dataset, number of classes),
+   *          where the number of classes is inferred from data (and set to 0 for Regression)
+   */
+  private[mllib] def loadDatasets(
+      sc: SparkContext,
+      input: String,
+      dataFormat: String,
+      testInput: String,
+      algo: Algo,
+      fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
     // Load training data and cache it.
-    val origExamples = params.dataFormat match {
-      case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+    val origExamples = dataFormat match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
     }
     // For classification, re-index classes if needed.
-    val (examples, classIndexMap, numClasses) = params.algo match {
+    val (examples, classIndexMap, numClasses) = algo match {
       case Classification => {
         // classCounts: class --> # examples in class
         val classCounts = origExamples.map(_.label).countByValue()
@@ -205,14 +215,14 @@ object DecisionTreeRunner {
     }
 
     // Create training, test sets.
-    val splits = if (params.testInput != "") {
+    val splits = if (testInput != "") {
       // Load testInput.
       val numFeatures = examples.take(1)(0).features.size
-      val origTestExamples = params.dataFormat match {
-        case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
-        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+      val origTestExamples = dataFormat match {
+        case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+        case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
       }
-      params.algo match {
+      algo match {
         case Classification => {
           // classCounts: class --> # examples in class
           val testExamples = {
@@ -229,17 +239,31 @@ object DecisionTreeRunner {
       }
     } else {
       // Split input into training, test.
-      examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+      examples.randomSplit(Array(1.0 - fracTest, fracTest))
     }
     val training = splits(0).cache()
     val test = splits(1).cache()
+
     val numTraining = training.count()
     val numTest = test.count()
-
     println(s"numTraining = $numTraining, numTest = $numTest.")
 
     examples.unpersist(blocking = false)
 
+    (training, test, numClasses)
+  }
+
+  def run(params: Params) {
+
+    val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+    val sc = new SparkContext(conf)
+
+    println(s"DecisionTreeRunner with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+      params.testInput, params.algo, params.fracTest)
+
     val impurityCalculator = params.impurity match {
       case Gini => impurity.Gini
       case Entropy => impurity.Entropy
@@ -338,7 +362,9 @@ object DecisionTreeRunner {
   /**
    * Calculates the mean squared error for regression.
    */
-  private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  private[mllib] def meanSquaredError(
+      tree: WeightedEnsembleModel,
+      data: RDD[LabeledPoint]): Double = {
     data.map { y =>
       val err = tree.predict(y.features) - y.label
       err * err

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
new file mode 100644
index 0000000..9b6db01
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ *       To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTrees {
+
+  case class Params(
+      input: String = null,
+      testInput: String = "",
+      dataFormat: String = "libsvm",
+      algo: String = "Classification",
+      maxDepth: Int = 5,
+      numIterations: Int = 10,
+      fracTest: Double = 0.2) extends AbstractParams[Params]
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("GradientBoostedTrees") {
+      head("GradientBoostedTrees: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("numIterations")
+        .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+        .action((x, c) => c.copy(numIterations = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+          s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+          s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
+      opt[String]("<dataFormat>")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
+      arg[String]("<input>")
+        .text("input path to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+
+    val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+    val sc = new SparkContext(conf)
+
+    println(s"GradientBoostedTrees with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+      params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+    val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+    boostingStrategy.numClassesForClassification = numClasses
+    boostingStrategy.numIterations = params.numIterations
+    boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+
+    val randomSeed = Utils.random.nextInt()
+    if (params.algo == "Classification") {
+      val startTime = System.nanoTime()
+      val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+      if (model.totalNumNodes < 30) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
+      val trainAccuracy =
+        new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+          .precision
+      println(s"Train accuracy = $trainAccuracy")
+      val testAccuracy =
+        new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+      println(s"Test accuracy = $testAccuracy")
+    } else if (params.algo == "Regression") {
+      val startTime = System.nanoTime()
+      val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+      if (model.totalNumNodes < 30) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
+      val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+      println(s"Train mean squared error = $trainMSE")
+      val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+      println(s"Test mean squared error = $testMSE")
+    }
+
+    sc.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
index 1a84720..f729344 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -17,30 +17,49 @@
 
 package org.apache.spark.mllib.tree
 
-import scala.collection.JavaConverters._
-
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.loss.Losses
-import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 
 /**
  * :: Experimental ::
- * A class that implements gradient boosting for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting
+ * for regression and binary classification problems.
+ *
+ * The implementation is based upon:
+ *   J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
+ *
+ * Notes:
+ *  - This currently can be run with several loss functions.  However, only SquaredError is
+ *    fully supported.  Specifically, the loss function should be used to compute the gradient
+ *    (to re-label training instances on each iteration) and to weight weak hypotheses.
+ *    Currently, gradients are computed correctly for the available loss functions,
+ *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
  * @param boostingStrategy Parameters for the gradient boosting algorithm
  */
 @Experimental
 class GradientBoosting (
     private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
 
+  boostingStrategy.weakLearnerParams.algo = Regression
+  boostingStrategy.weakLearnerParams.impurity = impurity.Variance
+
+  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+  boostingStrategy.weakLearnerParams.numClassesForClassification =
+    boostingStrategy.numClassesForClassification
+
+  boostingStrategy.assertValid()
+
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -51,6 +70,7 @@ class GradientBoosting (
     algo match {
       case Regression => GradientBoosting.boost(input, boostingStrategy)
       case Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoosting.boost(remappedInput, boostingStrategy)
       case _ =>
@@ -118,120 +138,32 @@ object GradientBoosting extends Logging {
   }
 
   /**
-   * Method to train a gradient boosting binary classification model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param numEstimators Number of estimators used in boosting stages. In other words,
-   *                      number of boosting iterations performed.
-   * @param loss Loss function used for minimization during gradient boosting.
-   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
-   *                     learning rate should be between in the interval (0, 1]
-   * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
-   * @param numClassesForClassification Number of classes for classification.
-   *                                    (Ignored for regression.)
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example,
-   *                                an entry (n -> k) implies the feature n is categorical with k
-   *                                categories 0, 1, 2, ... , k-1. It's important to note that
-   *                                features are zero-indexed.
-   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
-   *                          supported.)
-   * @return WeightedEnsembleModel that can be used for prediction
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
    */
-  def trainClassifier(
-      input: RDD[LabeledPoint],
-      numEstimators: Int,
-      loss: String,
-      learningRate: Double,
-      subsamplingRate: Double,
-      numClassesForClassification: Int,
-      categoricalFeaturesInfo: Map[Int, Int],
-      weakLearnerParams: Strategy): WeightedEnsembleModel = {
-    val lossType = Losses.fromString(loss)
-    val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
-      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
-      weakLearnerParams)
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Method to train a gradient boosting regression model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param numEstimators Number of estimators used in boosting stages. In other words,
-   *                      number of boosting iterations performed.
-   * @param loss Loss function used for minimization during gradient boosting.
-   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
-   *                     learning rate should be between in the interval (0, 1]
-   * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
-   * @param numClassesForClassification Number of classes for classification.
-   *                                    (Ignored for regression.)
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example,
-   *                                an entry (n -> k) implies the feature n is categorical with k
-   *                                categories 0, 1, 2, ... , k-1. It's important to note that
-   *                                features are zero-indexed.
-   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
-   *                          supported.)
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def trainRegressor(
-       input: RDD[LabeledPoint],
-       numEstimators: Int,
-       loss: String,
-       learningRate: Double,
-       subsamplingRate: Double,
-       numClassesForClassification: Int,
-       categoricalFeaturesInfo: Map[Int, Int],
-       weakLearnerParams: Strategy): WeightedEnsembleModel = {
-    val lossType = Losses.fromString(loss)
-    val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
-      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
-      weakLearnerParams)
-    new GradientBoosting(boostingStrategy).train(input)
+  def train(
+      input: JavaRDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    train(input.rdd, boostingStrategy)
   }
 
   /**
    * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
    */
   def trainClassifier(
-      input: RDD[LabeledPoint],
-      numEstimators: Int,
-      loss: String,
-      learningRate: Double,
-      subsamplingRate: Double,
-      numClassesForClassification: Int,
-      categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
-      weakLearnerParams: Strategy): WeightedEnsembleModel = {
-    trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
-      numClassesForClassification,
-      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
-      weakLearnerParams)
+      input: JavaRDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    trainClassifier(input.rdd, boostingStrategy)
   }
 
   /**
    * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
    */
   def trainRegressor(
-      input: RDD[LabeledPoint],
-      numEstimators: Int,
-      loss: String,
-      learningRate: Double,
-      subsamplingRate: Double,
-      numClassesForClassification: Int,
-      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
-      weakLearnerParams: Strategy): WeightedEnsembleModel = {
-    trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
-      numClassesForClassification,
-      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
-      weakLearnerParams)
+      input: JavaRDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    trainRegressor(input.rdd, boostingStrategy)
   }
 
-
   /**
    * Internal method for performing regression using trees as base learners.
    * @param input training dataset
@@ -247,15 +179,17 @@ object GradientBoosting extends Logging {
     timer.start("init")
 
     // Initialize gradient boosting parameters
-    val numEstimators = boostingStrategy.numEstimators
-    val baseLearners = new Array[DecisionTreeModel](numEstimators)
-    val baseLearnerWeights = new Array[Double](numEstimators)
+    val numIterations = boostingStrategy.numIterations
+    val baseLearners = new Array[DecisionTreeModel](numIterations)
+    val baseLearnerWeights = new Array[Double](numIterations)
     val loss = boostingStrategy.loss
     val learningRate = boostingStrategy.learningRate
     val strategy = boostingStrategy.weakLearnerParams
 
     // Cache input
-    input.persist(StorageLevel.MEMORY_AND_DISK)
+    if (input.getStorageLevel == StorageLevel.NONE) {
+      input.persist(StorageLevel.MEMORY_AND_DISK)
+    }
 
     timer.stop("init")
 
@@ -264,7 +198,7 @@ object GradientBoosting extends Logging {
     logDebug("##########")
     var data = input
 
-    // 1. Initialize tree
+    // Initialize tree
     timer.start("building tree 0")
     val firstTreeModel = new DecisionTree(strategy).train(data)
     baseLearners(0) = firstTreeModel
@@ -280,7 +214,7 @@ object GradientBoosting extends Logging {
       point.features))
 
     var m = 1
-    while (m < numEstimators) {
+    while (m < numIterations) {
       timer.start(s"building tree $m")
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
@@ -289,6 +223,9 @@ object GradientBoosting extends Logging {
       timer.stop(s"building tree $m")
       // Create partial model
       baseLearners(m) = model
+      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+      //       Technically, the weight should be optimized for the particular loss.
+      //       However, the behavior should be reasonable, though not optimal.
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
       val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
@@ -305,8 +242,6 @@ object GradientBoosting extends Logging {
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
-
-    // 3. Output classifier
     new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
 
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 501d9ff..abbda04 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -21,7 +21,6 @@ import scala.beans.BeanProperty
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
 import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
 
 /**
@@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
  * @param algo  Learning goal.  Supported:
  *              [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
  *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- *                      number of boosting iterations performed.
+ * @param numIterations Number of iterations of boosting.  In other words, the number of
+ *                      weak hypotheses used in the final model.
  * @param loss Loss function used for minimization during gradient boosting.
  * @param learningRate Learning rate for shrinking the contribution of each estimator. The
 *                     learning rate should be in the interval (0, 1].
- * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
  * @param numClassesForClassification Number of classes for classification.
  *                                    (Ignored for regression.)
+ *                                    This setting overrides any setting in [[weakLearnerParams]].
  *                                    Default value is 2 (binary classification).
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
- *                                number of discrete values they take. For example, an entry (n ->
- *                                k) implies the feature n is categorical with k categories 0,
- *                                1, 2, ... , k-1. It's important to note that features are
- *                                zero-indexed.
  * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
  *                          supported.
  */
 @Experimental
 case class BoostingStrategy(
     // Required boosting parameters
-    algo: Algo,
-    @BeanProperty var numEstimators: Int,
+    @BeanProperty var algo: Algo,
+    @BeanProperty var numIterations: Int,
     @BeanProperty var loss: Loss,
     // Optional boosting parameters
     @BeanProperty var learningRate: Double = 0.1,
-    @BeanProperty var subsamplingRate: Double = 1.0,
     @BeanProperty var numClassesForClassification: Int = 2,
-    @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
     @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
 
-  require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
-    s"$learningRate.")
-  require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
-    s"$learningRate.")
-
   // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
-  weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
   weakLearnerParams.numClassesForClassification = numClassesForClassification
-  weakLearnerParams.subsamplingRate = subsamplingRate
 
+  /**
+   * Sets Algorithm using a String.
+   */
+  def setAlgo(algo: String): Unit = algo match {
+    case "Classification" => setAlgo(Classification)
+    case "Regression" => setAlgo(Regression)
+  }
+
+  /**
+   * Check validity of parameters.
+   * Throws exception if invalid.
+   */
+  private[tree] def assertValid(): Unit = {
+    algo match {
+      case Classification =>
+        require(numClassesForClassification == 2)
+      case Regression =>
+        // nothing
+      case _ =>
+        throw new IllegalArgumentException(
+          s"BoostingStrategy given invalid algo parameter: $algo." +
+            s"  Valid settings are: Classification, Regression.")
+    }
+    require(learningRate > 0 && learningRate <= 1,
+      "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+  }
 }
 
 @Experimental
@@ -82,28 +93,17 @@ object BoostingStrategy {
    *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
    * @return Configuration for boosting algorithm
    */
-  def defaultParams(algo: Algo): BoostingStrategy = {
-    val treeStrategy = defaultWeakLearnerParams(algo)
+  def defaultParams(algo: String): BoostingStrategy = {
+    val treeStrategy = Strategy.defaultStrategy("Regression")
+    treeStrategy.maxDepth = 3
     algo match {
-      case Classification =>
-        new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
-      case Regression =>
-        new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+      case "Classification" =>
+        new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+      case "Regression" =>
+        new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
+          weakLearnerParams = treeStrategy)
       case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by boosting.")
     }
   }
-
-  /**
-   * Returns default configuration for the weak learner (decision tree) algorithm
-   * @param algo   Learning goal.  Supported:
-   *              [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
-   *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
-   * @return Configuration for weak learner
-   */
-  def defaultWeakLearnerParams(algo: Algo): Strategy = {
-    // Note: Regression tree used even for classification for GBT.
-    new Strategy(Regression, Variance, 3)
-  }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d09295c..b5b1f82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  */
 @Experimental
 class Strategy (
-    val algo: Algo,
+    @BeanProperty var algo: Algo,
     @BeanProperty var impurity: Impurity,
     @BeanProperty var maxDepth: Int,
     @BeanProperty var numClassesForClassification: Int = 2,
@@ -85,17 +85,9 @@ class Strategy (
     @BeanProperty var checkpointDir: Option[String] = None,
     @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
-  if (algo == Classification) {
-    require(numClassesForClassification >= 2)
-  }
-  require(minInstancesPerNode >= 1,
-    s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
-  require(maxMemoryInMB <= 10240,
-    s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
-  val isMulticlassClassification =
+  def isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2
-  val isMulticlassWithCategoricalFeatures
+  def isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
   /**
@@ -113,6 +105,23 @@ class Strategy (
   }
 
   /**
+   * Sets Algorithm using a String.
+   */
+  def setAlgo(algo: String): Unit = algo match {
+    case "Classification" => setAlgo(Classification)
+    case "Regression" => setAlgo(Regression)
+  }
+
+  /**
+   * Sets categoricalFeaturesInfo using a Java Map.
+   */
+  def setCategoricalFeaturesInfo(
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+    setCategoricalFeaturesInfo(
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+  }
+
+  /**
    * Check validity of parameters.
    * Throws exception if invalid.
    */
@@ -143,6 +152,26 @@ class Strategy (
         s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
         s" feature $feature has $arity categories.  The number of categories should be >= 2.")
     }
+    require(minInstancesPerNode >= 1,
+      s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+    require(maxMemoryInMB <= 10240,
+      s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
   }
+}
+
+@Experimental
+object Strategy {
 
+  /**
+   * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+   * @param algo  "Classification" or "Regression"
+   */
+  def defaultStrategy(algo: String): Strategy = algo match {
+    case "Classification" =>
+      new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+        numClassesForClassification = 2)
+    case "Regression" =>
+      new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+        numClassesForClassification = 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b3b6f6f/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index 970fff8..99a02ed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -22,9 +22,8 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.impurity.Variance
 import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 
 import org.apache.spark.mllib.util.LocalSparkContext
 
@@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext
 class GradientBoostingSuite extends FunSuite with LocalSparkContext {
 
   test("Regression with continuous features: SquaredError") {
-
     GradientBoostingSuite.testCombinations.foreach {
-      case (numEstimators, learningRate, subsamplingRate) =>
+      case (numIterations, learningRate, subsamplingRate) =>
         val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
 
         val dt = DecisionTree.train(remappedInput, treeStrategy)
 
-        val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
-          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+        val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+          learningRate, 1, treeStrategy)
 
         val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numEstimators)
+        assert(gbt.weakHypotheses.size === numIterations)
         val gbtTree = gbt.weakHypotheses(0)
 
         EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
   }
 
   test("Regression with continuous features: Absolute Error") {
-
     GradientBoostingSuite.testCombinations.foreach {
-      case (numEstimators, learningRate, subsamplingRate) =>
+      case (numIterations, learningRate, subsamplingRate) =>
         val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
 
         val dt = DecisionTree.train(remappedInput, treeStrategy)
 
-        val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
-          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+        val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+          learningRate, numClassesForClassification = 2, weakLearnerParams = treeStrategy)
 
         val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numEstimators)
+        assert(gbt.weakHypotheses.size === numIterations)
         val gbtTree = gbt.weakHypotheses(0)
 
         EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-
   test("Binary classification with continuous features: Log Loss") {
-
     GradientBoostingSuite.testCombinations.foreach {
-      case (numEstimators, learningRate, subsamplingRate) =>
+      case (numIterations, learningRate, subsamplingRate) =>
         val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
 
         val dt = DecisionTree.train(remappedInput, treeStrategy)
 
-        val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
-          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+        val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
+          learningRate, numClassesForClassification = 2, weakLearnerParams = treeStrategy)
 
         val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
-        assert(gbt.weakHypotheses.size === numEstimators)
+        assert(gbt.weakHypotheses.size === numIterations)
         val gbtTree = gbt.weakHypotheses(0)
 
         EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
@@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
 object GradientBoostingSuite {
 
   // Combinations for estimators, learning rates and subsamplingRate
-  val testCombinations
-    = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
 
 }

