Posted to commits@spark.apache.org by me...@apache.org on 2015/04/17 22:15:41 UTC

[2/2] spark git commit: [SPARK-6113] [ml] Stabilize DecisionTree API

[SPARK-6113] [ml] Stabilize DecisionTree API

This is a PR for cleaning up and finalizing the DecisionTree API.  PRs for ensembles will follow once this is merged.

### Goal

Here is the description copied from the JIRA (for both trees and ensembles):

> **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design.
> **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details.
> **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)** : This outlines current issues and the proposed API.

Overall code layout:
* The old API in mllib.tree.* will remain the same.
* The new API will reside in ml.classification.* and ml.regression.* (see the import sketch below).
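
For example (package layout as above; class names as introduced in this PR):

```
// Old, RDD-based API (unchanged):
import org.apache.spark.mllib.tree.DecisionTree
// New, Pipeline-based API:
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.regression.DecisionTreeRegressor
```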

### Summary of changes

Old API
* Exactly the same, except that one method in Loss was made private (not a breaking change, since that method was introduced after the Spark 1.3 release).

New APIs
* Under the Pipeline API (see the usage sketch after this list)
* The new API preserves functionality, except:
  * The new API does NOT store prob (the probability of the predicted label in classification).  I want it to store the full vector of class probabilities, but that should be done in a later PR.
* Use abstractions for parameters, estimators, and models to avoid code duplication
* Limit parameters to relevant algorithms
* For enum-like types, only expose Strings
  * We can make these pluggable later on by adding new parameters.  That is a far-future item.
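
For concreteness, a minimal sketch of training through the new Pipeline API (condensed from the DecisionTreeExample added in this PR; the column names and the `training` DataFrame are assumed):

```
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

// Index string labels and categorical features, then train a tree.
val labelIndexer = new StringIndexer()
  .setInputCol("labelString").setOutputCol("indexedLabel")
val featureIndexer = new VectorIndexer()
  .setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(10)
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxDepth(5)
  .setImpurity("gini")  // enum-like types are exposed as Strings
val pipeline = new Pipeline()
  .setStages(Array[PipelineStage](labelIndexer, featureIndexer, dt))
val model = pipeline.fit(training)  // training: DataFrame with "labelString" and "features"
```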

Test suites
* I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves.
* The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API.
  * Once the implementation has been moved to this new API, we should also move over the tests from the old suites that exercise the internals.

### Details

#### Changed names

Parameters
* useNodeIdCache -> cacheNodeIds (see the sketch below)
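
A small sketch of the rename from the user's side (DecisionTreeClassifier and its setter are introduced in this PR; the old Strategy field appears only in comments):

```
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Old API (mllib.tree): Strategy.useNodeIdCache = true
// New API (spark.ml): the same option is set via cacheNodeIds ...
val dt = new DecisionTreeClassifier().setCacheNodeIds(true)
// ... and is mapped back onto Strategy.useNodeIdCache by getOldStrategy
// (see treeParams.scala in the diff below).
```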

#### Other changes

* Split: Changed categories to a Set instead of a List

#### Non-decision tree changes
* AttributeGroup
  * Added parentheses to the toMetadata and toStructField methods.  (Parentheses had been removed in a previous PR, but I ran into an issue with the Scala compiler being unable to disambiguate a toMetadata method with no parentheses from a toMetadata method that takes one argument.)
* Attributes
  * Renamed: toMetadata -> toMetadataImpl
  * Added toMetadata methods which return ML metadata (keyed with "ML_ATTR"); see the sketch after this list
  * NominalAttribute: Added getNumValues method which examines both numValues and values.
* Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters)
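
As a small sketch of the new Attribute.toMetadata and NominalAttribute.getNumValues behavior (mirroring how StringIndexerModel now attaches column metadata; the label values here are just illustrative):

```
import org.apache.spark.ml.attribute.NominalAttribute

val labels = Array("a", "b", "c")
val attr = NominalAttribute.defaultAttr.withName("indexedLabel").withValues(labels)
val metadata = attr.toMetadata()   // ML metadata, keyed with "ML_ATTR"
println(attr.getNumValues)         // Some(3), derived from values since numValues is unset
```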

### Questions for reviewers

* Is "DecisionTreeClassificationModel" too long a name?
* Is this OK in the docs?
```
class DecisionTreeRegressor extends TreeRegressor[DecisionTreeRegressionModel] with DecisionTreeParams[DecisionTreeRegressor] with TreeRegressorParams[DecisionTreeRegressor]
```

### Future

We should open up the abstractions at some point.  E.g., it would be useful to be able to set tree-related parameters in one place and then pass them to multiple tree-based algorithms.

Follow-up JIRAs will be (in this order):
* Tree ensembles
* Deprecate old tree code
* Move DecisionTree implementation code to new API.
* Move tests from the old suites which test the internals.
* Update programming guide
* Python API
* Change RandomForest* to always use bootstrapping, even when numTrees = 1
* Provide the probability of the predicted label for classification.  Once the implementation has been moved to the new API and updated to maintain probabilities for all labels, we can expose those probabilities in the new API.

CC: mengxr  manishamde  codedeft  chouqin  MechCoder

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5530 from jkbradley/dt-api-dt and squashes the following commits:

6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead
ec17947 [Joseph K. Bradley] Updates based on code review.  Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final
5626c81 [Joseph K. Bradley] style fixes
f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR.  small cleanups
7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time)
e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example
119f407 [Joseph K. Bradley] added DecisionTreeClassifier example
0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged
f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet.  Need to add example as well
2532c9a [Joseph K. Bradley] partial move to spark.ml API, not done yet
c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a83571ac
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a83571ac
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a83571ac

Branch: refs/heads/master
Commit: a83571acc938582865efb41645aa1e414f339e46
Parents: 50ab8a6
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Fri Apr 17 13:15:36 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Apr 17 13:15:36 2015 -0700

----------------------------------------------------------------------
 .../spark/examples/ml/DecisionTreeExample.scala | 322 ++++++++++++++++
 .../spark/ml/attribute/AttributeGroup.scala     |  10 +-
 .../apache/spark/ml/attribute/attributes.scala  |  43 ++-
 .../classification/DecisionTreeClassifier.scala | 155 ++++++++
 .../apache/spark/ml/feature/StringIndexer.scala |   2 +-
 .../apache/spark/ml/impl/tree/treeParams.scala  | 300 +++++++++++++++
 .../scala/org/apache/spark/ml/package.scala     |  12 +
 .../org/apache/spark/ml/param/params.scala      |   3 +-
 .../ml/regression/DecisionTreeRegressor.scala   | 145 +++++++
 .../scala/org/apache/spark/ml/tree/Node.scala   | 205 ++++++++++
 .../scala/org/apache/spark/ml/tree/Split.scala  | 151 ++++++++
 .../org/apache/spark/ml/tree/treeModels.scala   |  60 +++
 .../apache/spark/ml/util/MetadataUtils.scala    |  82 ++++
 .../apache/spark/mllib/tree/DecisionTree.scala  |   5 +-
 .../spark/mllib/tree/GradientBoostedTrees.scala |  12 +-
 .../apache/spark/mllib/tree/RandomForest.scala  |   2 +-
 .../tree/configuration/BoostingStrategy.scala   |  10 +-
 .../spark/mllib/tree/loss/AbsoluteError.scala   |   5 +-
 .../apache/spark/mllib/tree/loss/LogLoss.scala  |   5 +-
 .../org/apache/spark/mllib/tree/loss/Loss.scala |   4 +-
 .../spark/mllib/tree/loss/SquaredError.scala    |   5 +-
 .../mllib/tree/model/DecisionTreeModel.scala    |   4 +-
 .../apache/spark/mllib/tree/model/Node.scala    |   2 +-
 .../mllib/tree/model/treeEnsembleModels.scala   |  32 +-
 .../JavaDecisionTreeClassifierSuite.java        |  98 +++++
 .../JavaDecisionTreeRegressorSuite.java         |  97 +++++
 .../ml/attribute/AttributeGroupSuite.scala      |   4 +-
 .../spark/ml/attribute/AttributeSuite.scala     |  42 +--
 .../DecisionTreeClassifierSuite.scala           | 274 ++++++++++++++
 .../spark/ml/feature/VectorIndexerSuite.scala   |   2 +-
 .../org/apache/spark/ml/impl/TreeTests.scala    | 132 +++++++
 .../regression/DecisionTreeRegressorSuite.scala |  91 +++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 373 ++++++++++---------
 33 files changed, 2426 insertions(+), 263 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
new file mode 100644
index 0000000..d4cc8de
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.ml.tree.DecisionTreeModel
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
+import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.{SQLContext, DataFrame}
+
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/run-example ml.DecisionTreeExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DecisionTreeExample {
+
+  case class Params(
+      input: String = null,
+      testInput: String = "",
+      dataFormat: String = "libsvm",
+      algo: String = "Classification",
+      maxDepth: Int = 5,
+      maxBins: Int = 32,
+      minInstancesPerNode: Int = 1,
+      minInfoGain: Double = 0.0,
+      numTrees: Int = 1,
+      featureSubsetStrategy: String = "auto",
+      fracTest: Double = 0.2,
+      cacheNodeIds: Boolean = false,
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DecisionTreeExample") {
+      head("DecisionTreeExample: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      opt[Int]("minInstancesPerNode")
+        .text(s"min number of instances required at child nodes to create the parent split," +
+          s" default: ${defaultParams.minInstancesPerNode}")
+        .action((x, c) => c.copy(minInstancesPerNode = x))
+      opt[Double]("minInfoGain")
+        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+        .action((x, c) => c.copy(minInfoGain = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+          s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[Boolean]("cacheNodeIds")
+        .text(s"whether to use node Id cache during training, " +
+          s"default: ${defaultParams.cacheNodeIds}")
+        .action((x, c) => c.copy(cacheNodeIds = x))
+      opt[String]("checkpointDir")
+        .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+         s"default: ${defaultParams.checkpointDir match {
+           case Some(strVal) => strVal
+           case None => "None"
+         }}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"how often to checkpoint the node Id cache, " +
+         s"default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+          s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
+      opt[String]("<dataFormat>")
+        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+        .action((x, c) => c.copy(dataFormat = x))
+      arg[String]("<input>")
+        .text("input path to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  /** Load a dataset from the given path, using the given format */
+  private[ml] def loadData(
+      sc: SparkContext,
+      path: String,
+      format: String,
+      expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
+    format match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, path)
+      case "libsvm" => expectedNumFeatures match {
+        case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
+        case None => MLUtils.loadLibSVMFile(sc, path)
+      }
+      case _ => throw new IllegalArgumentException(s"Bad data format: $format")
+    }
+  }
+
+  /**
+   * Load training and test data from files.
+   * @param input  Path to input dataset.
+   * @param dataFormat  "libsvm" or "dense"
+   * @param testInput  Path to test dataset.
+   * @param algo  Classification or Regression
+   * @param fracTest  Fraction of input data to hold out for testing.  Ignored if testInput given.
+   * @return  (training dataset, test dataset)
+   */
+  private[ml] def loadDatasets(
+      sc: SparkContext,
+      input: String,
+      dataFormat: String,
+      testInput: String,
+      algo: String,
+      fracTest: Double): (DataFrame, DataFrame) = {
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // Load training data
+    val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
+
+    // Load or create test set
+    val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
+      // Load testInput.
+      val numFeatures = origExamples.take(1)(0).features.size
+      val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures))
+      Array(origExamples, origTestExamples)
+    } else {
+      // Split input into training, test.
+      origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345)
+    }
+
+    // For classification, convert labels to Strings since we will index them later with
+    // StringIndexer.
+    def labelsToStrings(data: DataFrame): DataFrame = {
+      algo.toLowerCase match {
+        case "classification" =>
+          data.withColumn("labelString", data("label").cast(StringType))
+        case "regression" =>
+          data
+        case _ =>
+          throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+      }
+    }
+    val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+
+    (dataframes(0), dataframes(1))
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
+    val sc = new SparkContext(conf)
+    params.checkpointDir.foreach(sc.setCheckpointDir)
+    val algo = params.algo.toLowerCase
+
+    println(s"DecisionTreeExample with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training: DataFrame, test: DataFrame) =
+      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
+
+    val numTraining = training.count()
+    val numTest = test.count()
+    val numFeatures = training.select("features").first().getAs[Vector](0).size
+    println("Loaded data:")
+    println(s"  numTraining = $numTraining, numTest = $numTest")
+    println(s"  numFeatures = $numFeatures")
+
+    // Set up Pipeline
+    val stages = new mutable.ArrayBuffer[PipelineStage]()
+    // (1) For classification, re-index classes.
+    val labelColName = if (algo == "classification") "indexedLabel" else "label"
+    if (algo == "classification") {
+      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+      stages += labelIndexer
+    }
+    // (2) Identify categorical features using VectorIndexer.
+    //     Features with more than maxCategories values will be treated as continuous.
+    val featuresIndexer = new VectorIndexer().setInputCol("features")
+      .setOutputCol("indexedFeatures").setMaxCategories(10)
+    stages += featuresIndexer
+    // (3) Learn DecisionTree
+    val dt = algo match {
+      case "classification" =>
+        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+      case "regression" =>
+        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+      case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+    }
+    stages += dt
+    val pipeline = new Pipeline().setStages(stages.toArray)
+
+    // Fit the Pipeline
+    val startTime = System.nanoTime()
+    val pipelineModel = pipeline.fit(training)
+    val elapsedTime = (System.nanoTime() - startTime) / 1e9
+    println(s"Training time: $elapsedTime seconds")
+
+    // Get the trained Decision Tree from the fitted PipelineModel
+    val treeModel: DecisionTreeModel = algo match {
+      case "classification" =>
+        pipelineModel.getModel[DecisionTreeClassificationModel](
+          dt.asInstanceOf[DecisionTreeClassifier])
+      case "regression" =>
+        pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
+      case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+    }
+    if (treeModel.numNodes < 20) {
+      println(treeModel.toDebugString) // Print full model.
+    } else {
+      println(treeModel) // Print model summary.
+    }
+
+    // Predict on training
+    val trainingFullPredictions = pipelineModel.transform(training).cache()
+    val trainingPredictions = trainingFullPredictions.select("prediction")
+      .map(_.getDouble(0))
+    val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
+    // Predict on test data
+    val testFullPredictions = pipelineModel.transform(test).cache()
+    val testPredictions = testFullPredictions.select("prediction")
+      .map(_.getDouble(0))
+    val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
+
+    // For classification, print number of classes for reference.
+    if (algo == "classification") {
+      val numClasses =
+        MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
+          case Some(n) => n
+          case None => throw new RuntimeException(
+            "DecisionTreeExample had unknown failure when indexing labels for classification.")
+        }
+      println(s"numClasses = $numClasses.")
+    }
+
+    // Evaluate model on training, test data
+    algo match {
+      case "classification" =>
+        val trainingAccuracy =
+          new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
+        println(s"Train accuracy = $trainingAccuracy")
+        val testAccuracy =
+          new MulticlassMetrics(testPredictions.zip(testLabels)).precision
+        println(s"Test accuracy = $testAccuracy")
+      case "regression" =>
+        val trainingRMSE =
+          new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
+        println(s"Training root mean squared error (RMSE) = $trainingRMSE")
+        val testRMSE =
+          new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
+        println(s"Test root mean squared error (RMSE) = $testRMSE")
+      case _ =>
+        throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+    }
+
+    sc.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index aa27a66..d7dee8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -117,12 +117,12 @@ class AttributeGroup private (
         case numeric: NumericAttribute =>
           // Skip default numeric attributes.
           if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
-            numericMetadata += numeric.toMetadata(withType = false)
+            numericMetadata += numeric.toMetadataImpl(withType = false)
           }
         case nominal: NominalAttribute =>
-          nominalMetadata += nominal.toMetadata(withType = false)
+          nominalMetadata += nominal.toMetadataImpl(withType = false)
         case binary: BinaryAttribute =>
-          binaryMetadata += binary.toMetadata(withType = false)
+          binaryMetadata += binary.toMetadataImpl(withType = false)
       }
       val attrBldr = new MetadataBuilder
       if (numericMetadata.nonEmpty) {
@@ -151,7 +151,7 @@ class AttributeGroup private (
   }
 
   /** Converts to ML metadata */
-  def toMetadata: Metadata = toMetadata(Metadata.empty)
+  def toMetadata(): Metadata = toMetadata(Metadata.empty)
 
   /** Converts to a StructField with some existing metadata. */
   def toStructField(existingMetadata: Metadata): StructField = {
@@ -159,7 +159,7 @@ class AttributeGroup private (
   }
 
   /** Converts to a StructField. */
-  def toStructField: StructField = toStructField(Metadata.empty)
+  def toStructField(): StructField = toStructField(Metadata.empty)
 
   override def equals(other: Any): Boolean = {
     other match {

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 00b7566..5717d6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable {
    * Converts this attribute to [[Metadata]].
    * @param withType whether to include the type info
    */
-  private[attribute] def toMetadata(withType: Boolean): Metadata
+  private[attribute] def toMetadataImpl(withType: Boolean): Metadata
 
   /**
    * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
    * save space, because numeric type is the default attribute type. For nominal and binary
    * attributes, the type info is included.
    */
-  private[attribute] def toMetadata(): Metadata = {
+  private[attribute] def toMetadataImpl(): Metadata = {
     if (attrType == AttributeType.Numeric) {
-      toMetadata(withType = false)
+      toMetadataImpl(withType = false)
     } else {
-      toMetadata(withType = true)
+      toMetadataImpl(withType = true)
     }
   }
 
+  /** Converts to ML metadata with some existing metadata. */
+  def toMetadata(existingMetadata: Metadata): Metadata = {
+    new MetadataBuilder()
+      .withMetadata(existingMetadata)
+      .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
+      .build()
+  }
+
+  /** Converts to ML metadata */
+  def toMetadata(): Metadata = toMetadata(Metadata.empty)
+
   /**
    * Converts to a [[StructField]] with some existing metadata.
    * @param existingMetadata existing metadata to carry over
@@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable {
   def toStructField(existingMetadata: Metadata): StructField = {
     val newMetadata = new MetadataBuilder()
       .withMetadata(existingMetadata)
-      .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata())
+      .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
       .build()
     StructField(name.get, DoubleType, nullable = false, newMetadata)
   }
@@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable {
   /** Converts to a [[StructField]]. */
   def toStructField(): StructField = toStructField(Metadata.empty)
 
-  override def toString: String = toMetadata(withType = true).toString
+  override def toString: String = toMetadataImpl(withType = true).toString
 }
 
 /** Trait for ML attribute factories. */
@@ -210,7 +221,7 @@ class NumericAttribute private[ml] (
   override def isNominal: Boolean = false
 
   /** Convert this attribute to metadata. */
-  private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
     import org.apache.spark.ml.attribute.AttributeKeys._
     val bldr = new MetadataBuilder()
     if (withType) bldr.putString(TYPE, attrType.name)
@@ -353,6 +364,20 @@ class NominalAttribute private[ml] (
   /** Copy without the `numValues`. */
   def withoutNumValues: NominalAttribute = copy(numValues = None)
 
+  /**
+   * Get the number of values, either from `numValues` or from `values`.
+   * Return None if unknown.
+   */
+  def getNumValues: Option[Int] = {
+    if (numValues.nonEmpty) {
+      numValues
+    } else if (values.nonEmpty) {
+      Some(values.get.length)
+    } else {
+      None
+    }
+  }
+
   /** Creates a copy of this attribute with optional changes. */
   private def copy(
       name: Option[String] = name,
@@ -363,7 +388,7 @@ class NominalAttribute private[ml] (
     new NominalAttribute(name, index, isOrdinal, numValues, values)
   }
 
-  private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
     import org.apache.spark.ml.attribute.AttributeKeys._
     val bldr = new MetadataBuilder()
     if (withType) bldr.putString(TYPE, attrType.name)
@@ -465,7 +490,7 @@ class BinaryAttribute private[ml] (
     new BinaryAttribute(name, index, values)
   }
 
-  private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
     import org.apache.spark.ml.attribute.AttributeKeys._
     val bldr = new MetadataBuilder
     if (withType) bldr.putString(TYPE, attrType.name)

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000..3855e39
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassifier
+  extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+  with DecisionTreeParams
+  with TreeClassifierParams {
+
+  // Override parameter setters from parent trait for Java API compatibility.
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type =
+    super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type =
+    super.setCheckpointInterval(value)
+
+  override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): DecisionTreeClassificationModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+      case Some(n: Int) => n
+      case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
+        s" with invalid label column, without the number of classes specified.")
+        // TODO: Automatically index labels.
+    }
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val strategy = getOldStrategy(categoricalFeatures, numClasses)
+    val oldModel = OldDecisionTree.train(oldDataset, strategy)
+    DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  override private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int): OldStrategy = {
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
+    strategy.algo = OldAlgo.Classification
+    strategy.setImpurity(getOldImpurity)
+    strategy
+  }
+}
+
+object DecisionTreeClassifier {
+  /** Accessor for supported impurities */
+  final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassificationModel private[ml] (
+    override val parent: DecisionTreeClassifier,
+    override val fittingParamMap: ParamMap,
+    override val rootNode: Node)
+  extends PredictionModel[Vector, DecisionTreeClassificationModel]
+  with DecisionTreeModel with Serializable {
+
+  require(rootNode != null,
+    "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+  override protected def predict(features: Vector): Double = {
+    rootNode.predict(features)
+  }
+
+  override protected def copy(): DecisionTreeClassificationModel = {
+    val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
+  }
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldDecisionTreeModel = {
+    new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+  }
+}
+
+private[ml] object DecisionTreeClassificationModel {
+
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldDecisionTreeModel,
+      parent: DecisionTreeClassifier,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+    require(oldModel.algo == OldAlgo.Classification,
+      s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+        s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4d960df..23956c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -118,7 +118,7 @@ class StringIndexerModel private[ml] (
     }
     val outputColName = map(outputCol)
     val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues(labels).toStructField().metadata
+      .withName(outputColName).withValues(labels).toMetadata()
     dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
new file mode 100644
index 0000000..6f4509f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.impl.estimator.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
+  Impurity => OldImpurity, Variance => OldVariance}
+
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+  /**
+   * Maximum depth of the tree.
+   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * (default = 5)
+   * @group param
+   */
+  final val maxDepth: IntParam =
+    new IntParam(this, "maxDepth", "Maximum depth of the tree." +
+      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+
+  /**
+   * Maximum number of bins used for discretizing continuous features and for choosing how to split
+   * on features at each node.  More bins give higher granularity.
+   * Must be >= 2 and >= number of categories in any categorical feature.
+   * (default = 32)
+   * @group param
+   */
+  final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+    " discretizing continuous features.  Must be >=2 and >= number of categories for any" +
+    " categorical feature.")
+
+  /**
+   * Minimum number of instances each child must have after split.
+   * If a split causes the left or right child to have fewer than minInstancesPerNode,
+   * the split will be discarded as invalid.
+   * Should be >= 1.
+   * (default = 1)
+   * @group param
+   */
+  final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+    " number of instances each child must have after split.  If a split causes the left or right" +
+    " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+    " Should be >= 1.")
+
+  /**
+   * Minimum information gain for a split to be considered at a tree node.
+   * (default = 0.0)
+   * @group param
+   */
+  final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+    "Minimum information gain for a split to be considered at a tree node.")
+
+  /**
+   * Maximum memory in MB allocated to histogram aggregation.
+   * (default = 256 MB)
+   * @group expertParam
+   */
+  final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+    "Maximum memory in MB allocated to histogram aggregation.")
+
+  /**
+   * If false, the algorithm will pass trees to executors to match instances with nodes.
+   * If true, the algorithm will cache node IDs for each instance.
+   * Caching can speed up training of deeper trees.
+   * (default = false)
+   * @group expertParam
+   */
+  final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+    " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+    " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+    " trees.")
+
+  /**
+   * Specifies how often to checkpoint the cached node IDs.
+   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+   * [[org.apache.spark.SparkContext]].
+   * Must be >= 1.
+   * (default = 10)
+   * @group expertParam
+   */
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+    " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
+    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+    " checkpoint directory is set in the SparkContext. Must be >= 1.")
+
+  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+    maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+  /** @group setParam */
+  def setMaxDepth(value: Int): this.type = {
+    require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
+    set(maxDepth, value)
+    this.asInstanceOf[this.type]
+  }
+
+  /** @group getParam */
+  def getMaxDepth: Int = getOrDefault(maxDepth)
+
+  /** @group setParam */
+  def setMaxBins(value: Int): this.type = {
+    require(value >= 2, s"maxBins parameter must be >= 2.  Given bad value: $value")
+    set(maxBins, value)
+    this
+  }
+
+  /** @group getParam */
+  def getMaxBins: Int = getOrDefault(maxBins)
+
+  /** @group setParam */
+  def setMinInstancesPerNode(value: Int): this.type = {
+    require(value >= 1, s"minInstancesPerNode parameter must be >= 1.  Given bad value: $value")
+    set(minInstancesPerNode, value)
+    this
+  }
+
+  /** @group getParam */
+  def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+
+  /** @group setParam */
+  def setMinInfoGain(value: Double): this.type = {
+    set(minInfoGain, value)
+    this
+  }
+
+  /** @group getParam */
+  def getMinInfoGain: Double = getOrDefault(minInfoGain)
+
+  /** @group expertSetParam */
+  def setMaxMemoryInMB(value: Int): this.type = {
+    require(value > 0, s"maxMemoryInMB parameter must be > 0.  Given bad value: $value")
+    set(maxMemoryInMB, value)
+    this
+  }
+
+  /** @group expertGetParam */
+  def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+
+  /** @group expertSetParam */
+  def setCacheNodeIds(value: Boolean): this.type = {
+    set(cacheNodeIds, value)
+    this
+  }
+
+  /** @group expertGetParam */
+  def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+
+  /** @group expertSetParam */
+  def setCheckpointInterval(value: Int): this.type = {
+    require(value >= 1, s"checkpointInterval parameter must be >= 1.  Given bad value: $value")
+    set(checkpointInterval, value)
+    this
+  }
+
+  /** @group expertGetParam */
+  def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+
+  /**
+   * Create a Strategy instance to use with the old API.
+   * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
+   *       the default for single trees).
+   */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int): OldStrategy = {
+    val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+    strategy.checkpointInterval = getCheckpointInterval
+    strategy.maxBins = getMaxBins
+    strategy.maxDepth = getMaxDepth
+    strategy.maxMemoryInMB = getMaxMemoryInMB
+    strategy.minInfoGain = getMinInfoGain
+    strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.useNodeIdCache = getCacheNodeIds
+    strategy.numClasses = numClasses
+    strategy.categoricalFeaturesInfo = categoricalFeatures
+    strategy.subsamplingRate = 1.0 // default for individual trees
+    strategy
+  }
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+  /**
+   * Criterion used for information gain calculation (case-insensitive).
+   * Supported: "entropy" and "gini".
+   * (default = gini)
+   * @group param
+   */
+  val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+    " information gain calculation (case-insensitive). Supported options:" +
+    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+
+  setDefault(impurity -> "gini")
+
+  /** @group setParam */
+  def setImpurity(value: String): this.type = {
+    val impurityStr = value.toLowerCase
+    require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
+      s"Tree-based classifier was given unrecognized impurity: $value." +
+      s"  Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+    set(impurity, impurityStr)
+    this
+  }
+
+  /** @group getParam */
+  def getImpurity: String = getOrDefault(impurity)
+
+  /** Convert new impurity to old impurity. */
+  private[ml] def getOldImpurity: OldImpurity = {
+    getImpurity match {
+      case "entropy" => OldEntropy
+      case "gini" => OldGini
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(
+          s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+    }
+  }
+}
+
+private[ml] object TreeClassifierParams {
+  // These options should be lowercase.
+  val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+  /**
+   * Criterion used for information gain calculation (case-insensitive).
+   * Supported: "variance".
+   * (default = variance)
+   * @group param
+   */
+  val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+    " information gain calculation (case-insensitive). Supported options:" +
+    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+
+  setDefault(impurity -> "variance")
+
+  /** @group setParam */
+  def setImpurity(value: String): this.type = {
+    val impurityStr = value.toLowerCase
+    require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
+      s"Tree-based regressor was given unrecognized impurity: $value." +
+        s"  Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+    set(impurity, impurityStr)
+    this
+  }
+
+  /** @group getParam */
+  def getImpurity: String = getOrDefault(impurity)
+
+  /** Convert new impurity to old impurity. */
+  protected def getOldImpurity: OldImpurity = {
+    getImpurity match {
+      case "variance" => OldVariance
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(
+          s"TreeRegressorParams was given unrecognized impurity: $impurity")
+    }
+  }
+}
+
+private[ml] object TreeRegressorParams {
+  // These options should be lowercase.
+  val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/package.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index b45bd14..ac75e9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -32,6 +32,18 @@ package org.apache.spark
  * @groupname getParam Parameter getters
  * @groupprio getParam 6
  *
+ * @groupname expertParam (expert-only) Parameters
+ * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can
+ *            take. Users can set and get the parameter values through setters and getters,
+ *            respectively.
+ * @groupprio expertParam 7
+ *
+ * @groupname expertSetParam (expert-only) Parameter setters
+ * @groupprio expertSetParam 8
+ *
+ * @groupname expertGetParam (expert-only) Parameter getters
+ * @groupprio expertGetParam 9
+ *
  * @groupname Ungrouped Members
  * @groupprio Ungrouped 0
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 849c604..ddc5907 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -296,8 +296,9 @@ private[spark] object Params {
       paramMap: ParamMap,
       parent: E,
       child: M): Unit = {
+    val childParams = child.params.map(_.name).toSet
     parent.params.foreach { param =>
-      if (paramMap.contains(param)) {
+      if (paramMap.contains(param) && childParams.contains(param.name)) {
         child.set(child.getParam(param.name), paramMap(param))
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000..49a8b77
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class DecisionTreeRegressor
+  extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+  with DecisionTreeParams
+  with TreeRegressorParams {
+
+  // Override parameter setters from parent trait for Java API compatibility.
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type =
+    super.setCheckpointInterval(value)
+
+  override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): DecisionTreeRegressionModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val strategy = getOldStrategy(categoricalFeatures)
+    val oldModel = OldDecisionTree.train(oldDataset, strategy)
+    DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+    strategy.algo = OldAlgo.Regression
+    strategy.setImpurity(getOldImpurity)
+    strategy
+  }
+}
+
+object DecisionTreeRegressor {
+  /** Accessor for supported impurities */
+  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode  Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+    override val parent: DecisionTreeRegressor,
+    override val fittingParamMap: ParamMap,
+    override val rootNode: Node)
+  extends PredictionModel[Vector, DecisionTreeRegressionModel]
+  with DecisionTreeModel with Serializable {
+
+  require(rootNode != null,
+    "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+  override protected def predict(features: Vector): Double = {
+    rootNode.predict(features)
+  }
+
+  override protected def copy(): DecisionTreeRegressionModel = {
+    val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+  }
+
+  /** Convert to a model in the old API */
+  private[ml] def toOld: OldDecisionTreeModel = {
+    new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+  }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldDecisionTreeModel,
+      parent: DecisionTreeRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+    require(oldModel.algo == OldAlgo.Regression,
+      s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+        s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+  }
+}
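
For context on how the new estimator/model pair above is meant to be used, here is a minimal usage sketch (not part of this patch). It assumes a DataFrame named `training` with the default "label" and "features" columns already prepared; `training` and the parameter values are placeholders.

```scala
import org.apache.spark.ml.regression.DecisionTreeRegressor

// Hypothetical usage sketch of the new API; `training` is a placeholder DataFrame
// with "label" (Double) and "features" (Vector) columns.
val dt = new DecisionTreeRegressor()
  .setMaxDepth(5)
  .setMaxBins(32)
  .setImpurity("variance")   // regression currently supports "variance" only

val model = dt.fit(training)                  // DecisionTreeRegressionModel
println(model.toDebugString)                  // textual description of the fitted tree
val predictions = model.transform(training)   // adds a "prediction" column
```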

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000..d6e2203
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+  Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+  // TODO: Add aggregate stats (once available).  This will happen after we move the DecisionTree
+  //       code into the new API and deprecate the old API.
+
+  /** Prediction this node makes (or would make, if it is an internal node) */
+  def prediction: Double
+
+  /** Impurity measure at this node (for training data) */
+  def impurity: Double
+
+  /** Recursive prediction helper method */
+  private[ml] def predict(features: Vector): Double = prediction
+
+  /**
+   * Get the number of nodes in the tree below this node, including leaf nodes.
+   * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.
+   */
+  private[tree] def numDescendants: Int
+
+  /**
+   * Recursive print function.
+   * @param indentFactor  The number of spaces to add to each level of indentation.
+   */
+  private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+  /**
+   * Get depth of tree from this node.
+   * E.g.: Depth 0 means this is a leaf node.  Depth 1 means 1 internal node and 2 leaf nodes.
+   */
+  private[tree] def subtreeDepth: Int
+
+  /**
+   * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+   * @param id  Node ID using old format IDs
+   */
+  private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+  /**
+   * Create a new Node from the old Node format, recursively creating child nodes as needed.
+   */
+  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+    if (oldNode.isLeaf) {
+      // TODO: Once the implementation has been moved to this API, then include sufficient
+      //       statistics here.
+      new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+    } else {
+      val gain = if (oldNode.stats.nonEmpty) {
+        oldNode.stats.get.gain
+      } else {
+        0.0
+      }
+      new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+        gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+        split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+    }
+  }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction  Prediction this node makes
+ * @param impurity  Impurity measure at this node (for training data)
+ */
+final class LeafNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double) extends Node {
+
+  override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+  override private[ml] def predict(features: Vector): Double = prediction
+
+  override private[tree] def numDescendants: Int = 0
+
+  override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+    val prefix: String = " " * indentFactor
+    prefix + s"Predict: $prediction\n"
+  }
+
+  override private[tree] def subtreeDepth: Int = 0
+
+  override private[ml] def toOld(id: Int): OldNode = {
+    // NOTE: We do NOT store 'prob' in the new API currently.
+    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
+      None, None, None, None)
+  }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction  Prediction this node would make if it were a leaf node
+ * @param impurity  Impurity measure at this node (for training data)
+ * @param gain Information gain value.
+ *             Values < 0 indicate missing values; this quirk will be removed with future updates.
+ * @param leftChild  Left-hand child node
+ * @param rightChild  Right-hand child node
+ * @param split  Information about the test used to split to the left or right child.
+ */
+final class InternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    val gain: Double,
+    val leftChild: Node,
+    val rightChild: Node,
+    val split: Split) extends Node {
+
+  override def toString: String = {
+    s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+  }
+
+  override private[ml] def predict(features: Vector): Double = {
+    if (split.shouldGoLeft(features)) {
+      leftChild.predict(features)
+    } else {
+      rightChild.predict(features)
+    }
+  }
+
+  override private[tree] def numDescendants: Int = {
+    2 + leftChild.numDescendants + rightChild.numDescendants
+  }
+
+  override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+    val prefix: String = " " * indentFactor
+    prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+      leftChild.subtreeToString(indentFactor + 1) +
+      prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+      rightChild.subtreeToString(indentFactor + 1)
+  }
+
+  override private[tree] def subtreeDepth: Int = {
+    1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+  }
+
+  override private[ml] def toOld(id: Int): OldNode = {
+    assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+      + " since the old API does not support deep trees.")
+    // NOTE: We do NOT store 'prob' in the new API currently.
+    new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
+      Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+      Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+      Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+        new OldPredict(leftChild.prediction, prob = 0.0),
+        new OldPredict(rightChild.prediction, prob = 0.0))))
+  }
+}
+
+private object InternalNode {
+
+  /**
+   * Helper method for [[Node.subtreeToString()]].
+   * @param split  Split to print
+   * @param left  True to describe the left-hand branch of the split;
+   *              false to describe the right-hand branch.
+   */
+  private def splitToString(split: Split, left: Boolean): String = {
+    val featureStr = s"feature ${split.featureIndex}"
+    split match {
+      case contSplit: ContinuousSplit =>
+        if (left) {
+          s"$featureStr <= ${contSplit.threshold}"
+        } else {
+          s"$featureStr > ${contSplit.threshold}"
+        }
+      case catSplit: CategoricalSplit =>
+        val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+        if (left) {
+          s"$featureStr in $categoriesStr"
+        } else {
+          s"$featureStr not in $categoriesStr"
+        }
+    }
+  }
+}
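
To make the recursion above concrete, here is a small hand-built sketch (illustrative only, not part of this patch). Because the Node constructors and the predict/numDescendants/subtreeDepth helpers are package-private, assume the snippet is compiled inside the org.apache.spark.ml.tree package:

```scala
package org.apache.spark.ml.tree   // required: constructors and helpers are package-private

import org.apache.spark.mllib.linalg.Vectors

object TinyTreeSketch {
  def main(args: Array[String]): Unit = {
    // A depth-1 tree ("stump"): feature 0 <= 2.5 goes left, otherwise right.
    val left = new LeafNode(prediction = 0.0, impurity = 0.0)
    val right = new LeafNode(prediction = 1.0, impurity = 0.0)
    val root = new InternalNode(prediction = 0.5, impurity = 0.25, gain = 0.25,
      leftChild = left, rightChild = right, split = new ContinuousSplit(0, 2.5))

    println(root.predict(Vectors.dense(1.0)))   // 0.0 (routed to the left leaf)
    println(root.predict(Vectors.dense(4.0)))   // 1.0 (routed to the right leaf)
    println(root.numDescendants)                // 2
    println(root.subtreeDepth)                  // 1
  }
}
```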

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
new file mode 100644
index 0000000..cb940f6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
+import org.apache.spark.mllib.tree.model.{Split => OldSplit}
+
+
+/**
+ * Interface for a "Split," which specifies a test made at a decision tree node
+ * to choose the left or right path.
+ */
+sealed trait Split extends Serializable {
+
+  /** Index of feature which this split tests */
+  def featureIndex: Int
+
+  /** Return true (split to left) or false (split to right) */
+  private[ml] def shouldGoLeft(features: Vector): Boolean
+
+  /** Convert to old Split format */
+  private[tree] def toOld: OldSplit
+}
+
+private[ml] object Split {
+
+  def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
+    oldSplit.featureType match {
+      case OldFeatureType.Categorical =>
+        new CategoricalSplit(featureIndex = oldSplit.feature,
+          leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+      case OldFeatureType.Continuous =>
+        new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
+    }
+  }
+}
+
+/**
+ * Split which tests a categorical feature.
+ * @param featureIndex  Index of the feature to test
+ * @param leftCategories  If the feature value is in this set of categories, then the split goes
+ *                        left. Otherwise, it goes right.
+ * @param numCategories  Number of categories for this feature.
+ */
+final class CategoricalSplit(
+    override val featureIndex: Int,
+    leftCategories: Array[Double],
+    private val numCategories: Int)
+  extends Split {
+
+  require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+    s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+
+  /**
+   * If true, then "categories" is the set of categories for splitting to the left, and vice versa.
+   */
+  private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+
+  /** Set of categories determining the splitting rule, along with [[isLeft]]. */
+  private val categories: Set[Double] = {
+    if (isLeft) {
+      leftCategories.toSet
+    } else {
+      setComplement(leftCategories.toSet)
+    }
+  }
+
+  override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+    if (isLeft) {
+      categories.contains(features(featureIndex))
+    } else {
+      !categories.contains(features(featureIndex))
+    }
+  }
+
+  override def equals(o: Any): Boolean = {
+    o match {
+      case other: CategoricalSplit => featureIndex == other.featureIndex &&
+        isLeft == other.isLeft && categories == other.categories
+      case _ => false
+    }
+  }
+
+  override private[tree] def toOld: OldSplit = {
+    val oldCats = if (isLeft) {
+      categories
+    } else {
+      setComplement(categories)
+    }
+    OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
+  }
+
+  /** Get sorted categories which split to the left */
+  def getLeftCategories: Array[Double] = {
+    val cats = if (isLeft) categories else setComplement(categories)
+    cats.toArray.sorted
+  }
+
+  /** Get sorted categories which split to the right */
+  def getRightCategories: Array[Double] = {
+    val cats = if (isLeft) setComplement(categories) else categories
+    cats.toArray.sorted
+  }
+
+  /** [0, numCategories) \ cats */
+  private def setComplement(cats: Set[Double]): Set[Double] = {
+    Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
+  }
+}
+
+/**
+ * Split which tests a continuous feature.
+ * @param featureIndex  Index of the feature to test
+ * @param threshold  If the feature value is <= this threshold, then the split goes left.
+ *                    Otherwise, it goes right.
+ */
+final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+
+  override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+    features(featureIndex) <= threshold
+  }
+
+  override def equals(o: Any): Boolean = {
+    o match {
+      case other: ContinuousSplit =>
+        featureIndex == other.featureIndex && threshold == other.threshold
+      case _ =>
+        false
+    }
+  }
+
+  override private[tree] def toOld: OldSplit = {
+    OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
+  }
+}
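
One subtlety above: CategoricalSplit internally stores whichever side of the split has fewer categories, and getLeftCategories/getRightCategories reconstruct the requested side either way. A small illustrative sketch of the public Split API (hypothetical feature and values):

```scala
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit}

// Feature 3 has 5 categories {0, 1, 2, 3, 4}; categories 1, 3, 4 go left.
// Since the left side is the larger one, the split internally stores the complement {0, 2}.
val catSplit = new CategoricalSplit(featureIndex = 3,
  leftCategories = Array(1.0, 3.0, 4.0), numCategories = 5)

catSplit.getLeftCategories    // Array(1.0, 3.0, 4.0)
catSplit.getRightCategories   // Array(0.0, 2.0)

// Continuous splits simply compare the feature value against a threshold.
val contSplit = new ContinuousSplit(featureIndex = 0, threshold = 2.5)
```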

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000..8e3bc38
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+  /** Root of the decision tree */
+  def rootNode: Node
+
+  /** Number of nodes in the tree, including leaf nodes. */
+  def numNodes: Int = {
+    1 + rootNode.numDescendants
+  }
+
+  /**
+   * Depth of the tree.
+   * E.g.: Depth 0 means 1 leaf node.  Depth 1 means 1 internal node and 2 leaf nodes.
+   */
+  lazy val depth: Int = {
+    rootNode.subtreeDepth
+  }
+
+  /** Summary of the model */
+  override def toString: String = {
+    // Implementing classes should generally override this method to be more descriptive.
+    s"DecisionTreeModel of depth $depth with $numNodes nodes"
+  }
+
+  /** Full description of model */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    header + rootNode.subtreeToString(2)
+  }
+}
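
For reference, here is roughly what the summary methods above produce for a depth-1 regression tree with a single continuous split on feature 0 (illustrative output only; the exact lines come from Node.subtreeToString):

```scala
// Assuming `model` is a fitted DecisionTreeRegressionModel whose root splits feature 0 at 2.5:
println(model.toString)
// DecisionTreeRegressionModel of depth 1 with 3 nodes

println(model.toDebugString)
// DecisionTreeRegressionModel of depth 1 with 3 nodes
//   If (feature 0 <= 2.5)
//    Predict: 0.0
//   Else (feature 0 > 2.5)
//    Predict: 1.0
```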

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
new file mode 100644
index 0000000..c84c8b4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
+  NumericAttribute}
+import org.apache.spark.sql.types.StructField
+
+
+/**
+ * :: Experimental ::
+ *
+ * Helper utilities for tree-based algorithms
+ */
+@Experimental
+object MetadataUtils {
+
+  /**
+   * Examine a schema to identify the number of classes in a label column.
+   * Returns None if the number of labels is not specified, or if the label column is continuous.
+   */
+  def getNumClasses(labelSchema: StructField): Option[Int] = {
+    Attribute.fromStructField(labelSchema) match {
+      case numAttr: NumericAttribute => None
+      case binAttr: BinaryAttribute => Some(2)
+      case nomAttr: NominalAttribute => nomAttr.getNumValues
+    }
+  }
+
+  /**
+   * Examine a schema to identify categorical (Binary and Nominal) features.
+   *
+   * @param featuresSchema  Schema of the features column.
+   *                        If a feature does not have metadata, it is assumed to be continuous.
+   *                        If a feature is Nominal, then it must have the number of values
+   *                        specified.
+   * @return  Map: feature index --> number of categories.
+   *          The map's set of keys will be the set of categorical feature indices.
+   */
+  def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
+    val metadata = AttributeGroup.fromStructField(featuresSchema)
+    if (metadata.attributes.isEmpty) {
+      HashMap.empty[Int, Int]
+    } else {
+      metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
+        if (attr == null) {
+          Iterator()
+        } else {
+          attr match {
+            case numAttr: NumericAttribute => Iterator()
+            case binAttr: BinaryAttribute => Iterator(idx -> 2)
+            case nomAttr: NominalAttribute =>
+              nomAttr.getNumValues match {
+                case Some(numValues: Int) => Iterator(idx -> numValues)
+                case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
+                  " Nominal (categorical), but it does not have the number of values specified.")
+              }
+          }
+        }
+      }.toMap
+    }
+  }
+
+}
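
To show how these helpers interact with the ML attribute metadata touched by this PR, here is a minimal sketch (hypothetical feature layout; it assumes the attribute toStructField()/fromStructField round trip behaves as described in the summary above):

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute,
  NominalAttribute, NumericAttribute}
import org.apache.spark.ml.util.MetadataUtils

// Hypothetical features vector with 3 elements:
// index 0 continuous, index 1 binary, index 2 nominal with 4 values.
val attrs: Array[Attribute] = Array(
  NumericAttribute.defaultAttr,
  BinaryAttribute.defaultAttr,
  NominalAttribute.defaultAttr.withNumValues(4))
val featuresField = new AttributeGroup("features", attrs).toStructField()

MetadataUtils.getCategoricalFeatures(featuresField)   // Map(1 -> 2, 2 -> 4)

// The label column works similarly: a NominalAttribute with a known number of values
// yields Some(numClasses); a plain NumericAttribute yields None (continuous label).
val labelField = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toStructField()
MetadataUtils.getNumClasses(labelField)               // Some(3)
```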

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b9d0c56..dfe3a0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
-    assert(splits.length > 0)
+    // TODO: Do not fail; just ignore the useless feature.
+    assert(splits.length > 0,
+      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+        "  Please remove this feature and then try again.")
     // set number of splits accordingly
     metadata.setNumSplits(featureIndex, splits.length)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index c02c79f..0e31c7e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   /**
    * Method to validate a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param validationInput Validation dataset:
-                          RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-                          Should be different from and follow the same distribution as input.
-                          e.g., these two datasets could be created from an original dataset
-                          by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+   * @param validationInput Validation dataset.
+   *                        This dataset should be different from the training dataset,
+   *                        but it should follow the same distribution.
+   *                        E.g., these two datasets could be created from an original dataset
+   *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
    * @return a gradient boosted trees model that can be used for prediction
    */
   def runWithValidation(
@@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging {
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
-    val startingModel = new GradientBoostedTreesModel(
-      Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
 
     var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
       computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index db01f2e..055e60c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
         nodeIdCache.get.deleteAllCheckpoints()
       } catch {
         case e:IOException =>
-          logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
+          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 664c8df..2d6b015 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -89,14 +89,14 @@ object BoostingStrategy {
    * @return Configuration for boosting algorithm
    */
   def defaultParams(algo: Algo): BoostingStrategy = {
-    val treeStragtegy = Strategy.defaultStategy(algo)
-    treeStragtegy.maxDepth = 3
+    val treeStrategy = Strategy.defaultStategy(algo)
+    treeStrategy.maxDepth = 3
     algo match {
       case Algo.Classification =>
-        treeStragtegy.numClasses = 2
-        new BoostingStrategy(treeStragtegy, LogLoss)
+        treeStrategy.numClasses = 2
+        new BoostingStrategy(treeStrategy, LogLoss)
       case Algo.Regression =>
-        new BoostingStrategy(treeStragtegy, SquaredError)
+        new BoostingStrategy(treeStrategy, SquaredError)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by boosting.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 6f570b4..2bdef73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
 
 /**
  * :: DeveloperApi ::
@@ -45,9 +45,8 @@ object AbsoluteError extends Loss {
     if (label - prediction < 0) 1.0 else -1.0
   }
 
-  override def computeError(prediction: Double, label: Double): Double = {
+  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     math.abs(err)
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 24ee9f3..778c245 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
+
 
 /**
  * :: DeveloperApi ::
@@ -47,10 +47,9 @@ object LogLoss extends Loss {
     - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
-  override def computeError(prediction: Double, label: Double): Double = {
+  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
     val margin = 2.0 * label * prediction
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index d3b82b7..64ffccb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
+
 /**
  * :: DeveloperApi ::
  * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -57,6 +58,5 @@ trait Loss extends Serializable {
    * @param label True label.
    * @return Measure of model error on datapoint.
    */
-  def computeError(prediction: Double, label: Double): Double
-
+  private[mllib] def computeError(prediction: Double, label: Double): Double
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 58857ae..a5582d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
 
 /**
  * :: DeveloperApi ::
@@ -45,9 +45,8 @@ object SquaredError extends Loss {
     2.0 * (prediction - label)
   }
 
-  override def computeError(prediction: Double, label: Double): Double = {
+  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
     val err = prediction - label
     err * err
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c9bafd6..331af42 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
   }
 
-  override protected def formatVersion: String = "1.0"
+  override protected def formatVersion: String = DecisionTreeModel.formatVersion
 }
 
 object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
 
+  private[spark] def formatVersion: String = "1.0"
+
   private[tree] object SaveLoadV1_0 {
 
     def thisFormatVersion: String = "1.0"

http://git-wip-us.apache.org/repos/asf/spark/blob/a83571ac/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 4f72bb8..708ba04 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -175,7 +175,7 @@ class Node (
   }
 }
 
-private[tree] object Node {
+private[spark] object Node {
 
   /**
    * Return a node with the given node id (but nothing else set).

