You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/22 06:44:49 UTC

spark git commit: [SPARK-6113] [ML] Small cleanups after original tree API PR

Repository: spark
Updated Branches:
  refs/heads/master 70f9f8ff3 -> 607eff0ed


[SPARK-6113] [ML] Small cleanups after original tree API PR

This PR makes a few small clean-ups.  With this PR, all spark.ml tree components have `private[ml]` constructors.

CC: mengxr

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits:

2263b5b [Joseph K. Bradley] Added note about tree example issue.
bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/607eff0e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/607eff0e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/607eff0e

Branch: refs/heads/master
Commit: 607eff0edfc10a1473fa9713a0500bf09f105c13
Parents: 70f9f8f
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue Apr 21 21:44:44 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 21 21:44:44 2015 -0700

----------------------------------------------------------------------
 .../spark/examples/ml/DecisionTreeExample.scala | 25 +++++++++++++++-----
 .../apache/spark/ml/impl/tree/treeParams.scala  |  4 ++--
 .../scala/org/apache/spark/ml/tree/Split.scala  |  7 +++---
 3 files changed, 25 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/607eff0e/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 921b396..2cd515c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
  * {{{
  * ./bin/run-example ml.DecisionTreeExample [options]
  * }}}
+ * Note that Decision Trees can take a large amount of memory.  If the run-example command above
+ * fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
+ *   [examples JAR path] [options]
+ * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
 object DecisionTreeExample {
@@ -70,7 +77,7 @@ object DecisionTreeExample {
     val parser = new OptionParser[Params]("DecisionTreeExample") {
       head("DecisionTreeExample: an example decision tree app.")
       opt[String]("algo")
-        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
         .action((x, c) => c.copy(algo = x))
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -222,18 +229,23 @@ object DecisionTreeExample {
     // (1) For classification, re-index classes.
     val labelColName = if (algo == "classification") "indexedLabel" else "label"
     if (algo == "classification") {
-      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+      val labelIndexer = new StringIndexer()
+        .setInputCol("labelString")
+        .setOutputCol(labelColName)
       stages += labelIndexer
     }
     // (2) Identify categorical features using VectorIndexer.
     //     Features with more than maxCategories values will be treated as continuous.
-    val featuresIndexer = new VectorIndexer().setInputCol("features")
-      .setOutputCol("indexedFeatures").setMaxCategories(10)
+    val featuresIndexer = new VectorIndexer()
+      .setInputCol("features")
+      .setOutputCol("indexedFeatures")
+      .setMaxCategories(10)
     stages += featuresIndexer
     // (3) Learn DecisionTree
     val dt = algo match {
       case "classification" =>
-        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+        new DecisionTreeClassifier()
+          .setFeaturesCol("indexedFeatures")
           .setLabelCol(labelColName)
           .setMaxDepth(params.maxDepth)
           .setMaxBins(params.maxBins)
@@ -242,7 +254,8 @@ object DecisionTreeExample {
           .setCacheNodeIds(params.cacheNodeIds)
           .setCheckpointInterval(params.checkpointInterval)
       case "regression" =>
-        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+        new DecisionTreeRegressor()
+          .setFeaturesCol("indexedFeatures")
           .setLabelCol(labelColName)
           .setMaxDepth(params.maxDepth)
           .setMaxBins(params.maxBins)

http://git-wip-us.apache.org/repos/asf/spark/blob/607eff0e/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index 6f4509f..eb2609f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   def setMaxDepth(value: Int): this.type = {
     require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
     set(maxDepth, value)
-    this.asInstanceOf[this.type]
+    this
   }
 
   /** @group getParam */
@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
   def getImpurity: String = getOrDefault(impurity)
 
   /** Convert new impurity to old impurity. */
-  protected def getOldImpurity: OldImpurity = {
+  private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
       case "variance" => OldVariance
       case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/607eff0e/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index cb940f6..708c769 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
   private[tree] def toOld: OldSplit
 }
 
-private[ml] object Split {
+private[tree] object Split {
 
   def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
     oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
  *                        left. Otherwise, it goes right.
  * @param numCategories  Number of categories for this feature.
  */
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
     override val featureIndex: Int,
     leftCategories: Array[Double],
     private val numCategories: Int)
@@ -130,7 +130,8 @@ final class CategoricalSplit(
  * @param threshold  If the feature value is <= this threshold, then the split goes left.
  *                    Otherwise, it goes right.
  */
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+  extends Split {
 
   override private[ml] def shouldGoLeft(features: Vector): Boolean = {
     features(featureIndex) <= threshold


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org