You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/20 22:12:38 UTC

git commit: [SPARK-3207][MLLIB]Choose splits for continuous features in DecisionTree more adaptively

Repository: spark
Updated Branches:
  refs/heads/master 4afe9a485 -> eadc4c590


[SPARK-3207][MLLIB]Choose splits for continuous features in DecisionTree more adaptively

DecisionTree splits on continuous features by choosing an array of values from a subsample of the data.
Currently, it does not check for identical values in the subsample, so it could end up having multiple copies of the same split. In this PR, we choose splits for a continuous feature in 3 steps:

1. Sort sample values for this feature
2. Count the number of occurrences of each distinct value
3. Iterate over the value-count array computed in step 2 to choose splits.

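After the splits are found, `numSplits` and `numBins` in the metadata are updated to match the number of splits actually chosen (`numBins` = `numSplits` + 1 for a continuous feature).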

CC: mengxr manishamde jkbradley, please help me review this, thanks.

Author: Qiping Li <li...@gmail.com>
Author: chouqin <li...@gmail.com>
Author: liqi <li...@gmail.com>
Author: qiping.lqp <qi...@alibaba-inc.com>

Closes #2780 from chouqin/dt-findsplits and squashes the following commits:

18d0301 [Qiping Li] check explicitly findsplits return distinct splits
8dc28ab [chouqin] remove blank lines
ffc920f [chouqin] adjust code based on comments and add more test cases
9857039 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-findsplits
d353596 [qiping.lqp] fix pyspark doc test
9e64699 [Qiping Li] fix random forest unit test
3c72913 [Qiping Li] fix random forest unit test
092efcb [Qiping Li] fix bug
f69f47f [Qiping Li] fix bug
ab303a4 [Qiping Li] fix bug
af6dc97 [Qiping Li] fix bug
2a8267a [Qiping Li] fix bug
c339a61 [Qiping Li] fix bug
369f812 [Qiping Li] fix style
8f46af6 [Qiping Li] add comments and unit test
9e7138e [Qiping Li] Merge branch 'dt-findsplits' of https://github.com/chouqin/spark into dt-findsplits
1b25a35 [Qiping Li] Merge branch 'master' of https://github.com/apache/spark into dt-findsplits
0cd744a [liqi] fix bug
3652823 [Qiping Li] fix bug
af7cb79 [Qiping Li] Choose splits for continuous features in DecisionTree more adaptively


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eadc4c59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eadc4c59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eadc4c59

Branch: refs/heads/master
Commit: eadc4c590ee43572528da55d84ed65f09153e857
Parents: 4afe9a4
Author: Qiping Li <li...@gmail.com>
Authored: Mon Oct 20 13:12:26 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Oct 20 13:12:26 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 104 +++++++++++++++++--
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |  11 ++
 .../spark/mllib/tree/DecisionTreeSuite.scala    |  68 +++++++++++-
 .../spark/mllib/tree/RandomForestSuite.scala    |   5 +-
 python/pyspark/mllib/tree.py                    |   4 +-
 5 files changed, 176 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eadc4c59/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 03eeaa7..6737a2f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -909,32 +911,39 @@ object DecisionTree extends Serializable with Logging {
         // Iterate over all features.
         var featureIndex = 0
         while (featureIndex < numFeatures) {
-          val numSplits = metadata.numSplits(featureIndex)
-          val numBins = metadata.numBins(featureIndex)
           if (metadata.isContinuous(featureIndex)) {
-            val numSamples = sampledInput.length
+            val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
+            val featureSplits = findSplitsForContinuousFeature(featureSamples,
+              metadata, featureIndex)
+
+            val numSplits = featureSplits.length
+            val numBins = numSplits + 1
+            logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
             splits(featureIndex) = new Array[Split](numSplits)
             bins(featureIndex) = new Array[Bin](numBins)
-            val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-            val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
-            logDebug("stride = " + stride)
-            for (splitIndex <- 0 until numSplits) {
-              val sampleIndex = splitIndex * stride.toInt
-              // Set threshold halfway in between 2 samples.
-              val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+
+            var splitIndex = 0
+            while (splitIndex < numSplits) {
+              val threshold = featureSplits(splitIndex)
               splits(featureIndex)(splitIndex) =
                 new Split(featureIndex, threshold, Continuous, List())
+              splitIndex += 1
             }
             bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
               splits(featureIndex)(0), Continuous, Double.MinValue)
-            for (splitIndex <- 1 until numSplits) {
+
+            splitIndex = 1
+            while (splitIndex < numSplits) {
               bins(featureIndex)(splitIndex) =
                 new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
                   Continuous, Double.MinValue)
+              splitIndex += 1
             }
             bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
               new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
           } else {
+            val numSplits = metadata.numSplits(featureIndex)
+            val numBins = metadata.numBins(featureIndex)
             // Categorical feature
             val featureArity = metadata.featureArity(featureIndex)
             if (metadata.isUnordered(featureIndex)) {
@@ -1011,4 +1020,77 @@ object DecisionTree extends Serializable with Logging {
     categories
   }
 
+  /**
+   * Find splits for a continuous feature
+   * NOTE: Returned number of splits is set based on `featureSamples` and
+   *       could be different from the specified `numSplits`.
+   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+   * @param featureSamples feature values of each sample
+   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numbins` will be changed accordingly
+   *                       if there are not enough splits to be found
+   * @param featureIndex feature index to find splits
+   * @return array of splits
+   */
+  private[tree] def findSplitsForContinuousFeature(
+      featureSamples: Array[Double],
+      metadata: DecisionTreeMetadata,
+      featureIndex: Int): Array[Double] = {
+    require(metadata.isContinuous(featureIndex),
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+    val splits = {
+      val numSplits = metadata.numSplits(featureIndex)
+
+      // get count for each distinct value
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // if possible splits is not enough or just enough, just return all possible splits
+      val possibleSplits = valueCounts.length
+      if (possibleSplits <= numSplits) {
+        valueCounts.map(_._1)
+      } else {
+        // stride between splits
+        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        logDebug("stride = " + stride)
+
+        // iterate `valueCount` to find splits
+        val splits = new ArrayBuffer[Double]
+        var index = 1
+        // currentCount: sum of counts of values that have been visited
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // If `currentCount` is closest value to `targetCount`,
+        // then current value is a split threshold.
+        // After finding a split threshold, `targetCount` is added by stride.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
+          // If adding count of current value to currentCount
+          // makes the gap between currentCount and targetCount smaller,
+          // previous value is a split threshold.
+          if (previousGap < currentGap) {
+            splits.append(valueCounts(index - 1)._1)
+            targetCount += stride
+          }
+          index += 1
+        }
+
+        splits.toArray
+      }
+    }
+
+    assert(splits.length > 0)
+    // set number of splits accordingly
+    metadata.setNumSplits(featureIndex, splits.length)
+
+    splits
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eadc4c59/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 772c026..5bc0f26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -76,6 +76,17 @@ private[tree] class DecisionTreeMetadata(
     numBins(featureIndex) - 1
   }
 
+
+  /**
+   * Set number of splits for a continuous feature.
+   * For a continuous feature, number of bins is number of splits plus 1.
+   */
+  def setNumSplits(featureIndex: Int, numSplits: Int) {
+    require(isContinuous(featureIndex),
+      s"Only number of bin for a continuous feature can be set.")
+    numBins(featureIndex) = numSplits + 1
+  }
+
   /**
    * Indicates if feature subsampling is being used.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/eadc4c59/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 98a72b0..8fc5e11 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
@@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
   }
 
+  test("find splits for a continuous feature") {
+    // find splits for normal case
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(6), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array.fill(200000)(math.random)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 5)
+      assert(fakeMetadata.numSplits(0) === 5)
+      assert(fakeMetadata.numBins(0) === 6)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // find splits should not return identical splits
+    // when there are not enough split candidates, reduce the number of splits in metadata
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(5), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 3)
+      assert(fakeMetadata.numSplits(0) === 3)
+      assert(fakeMetadata.numBins(0) === 4)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // find splits when most samples close to the minimum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 2)
+      assert(fakeMetadata.numSplits(0) === 2)
+      assert(fakeMetadata.numBins(0) === 3)
+      assert(splits(0) === 2.0)
+      assert(splits(1) === 3.0)
+    }
+
+    // find splits when most samples close to the maximum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 1)
+      assert(fakeMetadata.numSplits(0) === 1)
+      assert(fakeMetadata.numBins(0) === 2)
+      assert(splits(0) === 1.0)
+    }
+  }
+
   test("Multiclass classification with unordered categorical features:" +
       " split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()

http://git-wip-us.apache.org/repos/asf/spark/blob/eadc4c59/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index fb44ceb..6b13765 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -93,8 +93,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo)
 
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)

http://git-wip-us.apache.org/repos/asf/spark/blob/eadc4c59/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 0938eeb..64ee79d 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -153,9 +153,9 @@ class DecisionTree(object):
         DecisionTreeModel classifier of depth 1 with 3 nodes
         >>> print model.toDebugString(),  # it already has newline
         DecisionTreeModel classifier of depth 1 with 3 nodes
-          If (feature 0 <= 0.5)
+          If (feature 0 <= 0.0)
            Predict: 0.0
-          Else (feature 0 > 0.5)
+          Else (feature 0 > 0.0)
            Predict: 1.0
         >>> model.predict(array([1.0])) > 0
         True


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org