You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/01 10:03:32 UTC

git commit: [SPARK-3751] [mllib] DecisionTree: example update + print options

Repository: spark
Updated Branches:
  refs/heads/master eb43043f4 -> 7bf6cc970


[SPARK-3751] [mllib] DecisionTree: example update + print options

DecisionTreeRunner functionality additions:
* Allow user to pass in a test dataset
* Do not print full model if the model is too large.

As part of this, DecisionTreeModel and RandomForestModel are modified to allow printing less info.  Proposed updates:
* toString: prints model summary
* toDebugString: prints full model (named after RDD.toDebugString)

Similar update to Python API:
* __repr__() now prints a model summary
* toDebugString() now prints the full model

CC: mengxr chouqin manishamde codedeft.  Small update (whoever can take a look).  Thanks!

Author: Joseph K. Bradley <jo...@gmail.com>

Closes #2604 from jkbradley/dtrunner-update and squashes the following commits:

b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before
07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model
1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing.
22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7bf6cc97
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7bf6cc97
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7bf6cc97

Branch: refs/heads/master
Commit: 7bf6cc9701cbb0f77fb85a412e387fb92274fca5
Parents: eb43043
Author: Joseph K. Bradley <jo...@gmail.com>
Authored: Wed Oct 1 01:03:24 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Oct 1 01:03:24 2014 -0700

----------------------------------------------------------------------
 .../examples/mllib/DecisionTreeRunner.scala     | 99 ++++++++++++++------
 .../mllib/tree/model/DecisionTreeModel.scala    | 14 ++-
 .../mllib/tree/model/RandomForestModel.scala    | 30 ++++--
 python/pyspark/mllib/tree.py                    | 10 +-
 4 files changed, 111 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7bf6cc97/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 96fb068..4adc91d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -52,6 +52,7 @@ object DecisionTreeRunner {
 
   case class Params(
       input: String = null,
+      testInput: String = "",
       dataFormat: String = "libsvm",
       algo: Algo = Classification,
       maxDepth: Int = 5,
@@ -98,13 +99,18 @@ object DecisionTreeRunner {
           s"default: ${defaultParams.featureSubsetStrategy}")
         .action((x, c) => c.copy(featureSubsetStrategy = x))
       opt[Double]("fracTest")
-        .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+          s"this option is ignored. default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
+      opt[String]("testInput")
+        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
+          s" default: ${defaultParams.testInput}")
+        .action((x, c) => c.copy(testInput = x))
       opt[String]("<dataFormat>")
         .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
         .action((x, c) => c.copy(dataFormat = x))
       arg[String]("<input>")
-        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+        .text("input path to labeled examples")
         .required()
         .action((x, c) => c.copy(input = x))
       checkConfig { params =>
@@ -141,7 +147,7 @@ object DecisionTreeRunner {
       case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
     }
     // For classification, re-index classes if needed.
-    val (examples, numClasses) = params.algo match {
+    val (examples, classIndexMap, numClasses) = params.algo match {
       case Classification => {
         // classCounts: class --> # examples in class
         val classCounts = origExamples.map(_.label).countByValue()
@@ -170,16 +176,40 @@ object DecisionTreeRunner {
           val frac = classCounts(c) / numExamples.toDouble
           println(s"$c\t$frac\t${classCounts(c)}")
         }
-        (examples, numClasses)
+        (examples, classIndexMap, numClasses)
       }
       case Regression =>
-        (origExamples, 0)
+        (origExamples, null, 0)
       case _ =>
         throw new IllegalArgumentException("Algo ${params.algo} not supported.")
     }
 
-    // Split into training, test.
-    val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+    // Create training, test sets.
+    val splits = if (params.testInput != "") {
+      // Load testInput.
+      val origTestExamples = params.dataFormat match {
+        case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
+        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+      }
+      params.algo match {
+        case Classification => {
+          // classCounts: class --> # examples in class
+          val testExamples = {
+            if (classIndexMap.isEmpty) {
+              origTestExamples
+            } else {
+              origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+            }
+          }
+          Array(examples, testExamples)
+        }
+        case Regression =>
+          Array(examples, origTestExamples)
+      }
+    } else {
+      // Split input into training, test.
+      examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+    }
     val training = splits(0).cache()
     val test = splits(1).cache()
     val numTraining = training.count()
@@ -206,32 +236,56 @@ object DecisionTreeRunner {
           minInfoGain = params.minInfoGain)
     if (params.numTrees == 1) {
       val model = DecisionTree.train(training, strategy)
-      println(model)
+      if (model.numNodes < 20) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
       if (params.algo == Classification) {
-        val accuracy =
+        val trainAccuracy =
+          new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+            .precision
+        println(s"Train accuracy = $trainAccuracy")
+        val testAccuracy =
           new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
-        println(s"Test accuracy = $accuracy")
+        println(s"Test accuracy = $testAccuracy")
       }
       if (params.algo == Regression) {
-        val mse = meanSquaredError(model, test)
-        println(s"Test mean squared error = $mse")
+        val trainMSE = meanSquaredError(model, training)
+        println(s"Train mean squared error = $trainMSE")
+        val testMSE = meanSquaredError(model, test)
+        println(s"Test mean squared error = $testMSE")
       }
     } else {
       val randomSeed = Utils.random.nextInt()
       if (params.algo == Classification) {
         val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
           params.featureSubsetStrategy, randomSeed)
-        println(model)
-        val accuracy =
+        if (model.totalNumNodes < 30) {
+          println(model.toDebugString) // Print full model.
+        } else {
+          println(model) // Print model summary.
+        }
+        val trainAccuracy =
+          new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+            .precision
+        println(s"Train accuracy = $trainAccuracy")
+        val testAccuracy =
           new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
-        println(s"Test accuracy = $accuracy")
+        println(s"Test accuracy = $testAccuracy")
       }
       if (params.algo == Regression) {
         val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
           params.featureSubsetStrategy, randomSeed)
-        println(model)
-        val mse = meanSquaredError(model, test)
-        println(s"Test mean squared error = $mse")
+        if (model.totalNumNodes < 30) {
+          println(model.toDebugString) // Print full model.
+        } else {
+          println(model) // Print model summary.
+        }
+        val trainMSE = meanSquaredError(model, training)
+        println(s"Train mean squared error = $trainMSE")
+        val testMSE = meanSquaredError(model, test)
+        println(s"Test mean squared error = $testMSE")
       }
     }
 
@@ -239,15 +293,6 @@ object DecisionTreeRunner {
   }
 
   /**
-   * Calculates the classifier accuracy.
-   */
-  private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
-    val count = data.count()
-    correctCount.toDouble / count
-  }
-
-  /**
    * Calculates the mean squared error for regression.
    */
   private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7bf6cc97/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 271b2c4..ec1d99a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
   }
 
   /**
-   * Print full model.
+   * Print a summary of the model.
    */
   override def toString: String = algo match {
     case Classification =>
-      s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
+      s"DecisionTreeModel classifier of depth $depth with $numNodes nodes"
     case Regression =>
-      s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
+      s"DecisionTreeModel regressor of depth $depth with $numNodes nodes"
     case _ => throw new IllegalArgumentException(
       s"DecisionTreeModel given unknown algo parameter: $algo.")
   }
 
+  /**
+   * Print the full model to a string.
+   */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    header + topNode.subtreeToString(2)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7bf6cc97/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 538c0e2..4d66d6d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
   def numTrees: Int = trees.size
 
   /**
-   * Print full model.
+   * Get total number of nodes, summed over all trees in the forest.
    */
-  override def toString: String = {
-    val header = algo match {
-      case Classification =>
-        s"RandomForestModel classifier with $numTrees trees\n"
-      case Regression =>
-        s"RandomForestModel regressor with $numTrees trees\n"
-      case _ => throw new IllegalArgumentException(
-        s"RandomForestModel given unknown algo parameter: $algo.")
-    }
+  def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
+
+  /**
+   * Print a summary of the model.
+   */
+  override def toString: String = algo match {
+    case Classification =>
+      s"RandomForestModel classifier with $numTrees trees"
+    case Regression =>
+      s"RandomForestModel regressor with $numTrees trees"
+    case _ => throw new IllegalArgumentException(
+      s"RandomForestModel given unknown algo parameter: $algo.")
+  }
+
+  /**
+   * Print the full model to a string.
+   */
+  def toDebugString: String = {
+    val header = toString + "\n"
     header + trees.zipWithIndex.map { case (tree, treeIndex) =>
       s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
     }.fold("")(_ + _)

http://git-wip-us.apache.org/repos/asf/spark/blob/7bf6cc97/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index f59a818..afdcdbd 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -77,8 +77,13 @@ class DecisionTreeModel(object):
         return self._java_model.depth()
 
     def __repr__(self):
+        """ Print summary of model. """
         return self._java_model.toString()
 
+    def toDebugString(self):
+        """ Print full model. """
+        return self._java_model.toDebugString()
+
 
 class DecisionTree(object):
 
@@ -135,7 +140,6 @@ class DecisionTree(object):
         >>> from numpy import array
         >>> from pyspark.mllib.regression import LabeledPoint
         >>> from pyspark.mllib.tree import DecisionTree
-        >>> from pyspark.mllib.linalg import SparseVector
         >>>
         >>> data = [
         ...     LabeledPoint(0.0, [0.0]),
@@ -145,7 +149,9 @@ class DecisionTree(object):
         ... ]
         >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
         >>> print model,  # it already has newline
-        DecisionTreeModel classifier
+        DecisionTreeModel classifier of depth 1 with 3 nodes
+        >>> print model.toDebugString(),  # it already has newline
+        DecisionTreeModel classifier of depth 1 with 3 nodes
           If (feature 0 <= 0.5)
            Predict: 0.0
           Else (feature 0 > 0.5)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org