You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/10/11 00:04:15 UTC
spark git commit: [SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
Repository: spark
Updated Branches:
refs/heads/master 29f186bfd -> 03c40202f
[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
## What changes were proposed in this pull request?
The method `findSplitsForContinuousFeature` produces a nonsensical split for decision trees: when there are no more distinct values than requested splits, it returns every distinct value as a threshold, including the largest one. This PR removes that superfluous split and updates the unit tests accordingly. Additionally, the assertion that the number of found splits is `> 0` is removed; features with zero possible splits are now ignored instead.
## How was this patch tested?
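A unit test was added to check that finding splits for a constant feature (or for an empty sample) produces an empty array, and a second test checks that training on data with only constant features yields a single-node tree.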
Author: sethah <se...@gmail.com>
Closes #12374 from sethah/SPARK-14610.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/03c40202
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/03c40202
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/03c40202
Branch: refs/heads/master
Commit: 03c40202f36ea9fc93071b79fed21ed3f2190ba1
Parents: 29f186b
Author: sethah <se...@gmail.com>
Authored: Mon Oct 10 17:04:11 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Oct 10 17:04:11 2016 -0700
----------------------------------------------------------------------
.../spark/ml/tree/impl/RandomForest.scala | 31 +++++++-------
.../spark/ml/tree/impl/RandomForestSuite.scala | 44 ++++++++++++++++----
2 files changed, 52 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/03c40202/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 0b7ad92..b504f41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
node.stats
}
+ val validFeatureSplits =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
- Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
- val featureIndex = if (featuresForNode.nonEmpty) {
- featuresForNode.get.apply(featureIndexIdx)
- } else {
- featureIndexIdx
- }
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
- * @return array of splits
+ * @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Iterable[Double],
@@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
- val splits = {
+ val splits = if (featureSamples.isEmpty) {
+ Array.empty[Double]
+ } else {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
@@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// if possible splits is not enough or just enough, just return all possible splits
- val possibleSplits = valueCounts.length
+ val possibleSplits = valueCounts.length - 1
if (possibleSplits <= numSplits) {
- valueCounts.map(_._1)
+ valueCounts.map(_._1).init
} else {
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
splitsBuilder.result()
}
}
-
- // TODO: Do not fail; just ignore the useless feature.
- assert(splits.length > 0,
- s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
- " Please remove this feature and then try again.")
-
splits
}
http://git-wip-us.apache.org/repos/asf/spark/blob/03c40202/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 79b19ea..499d386 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 3)
+ assert(splits === Array(1.0, 2.0))
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 2)
- assert(splits(0) === 2.0)
- assert(splits(1) === 3.0)
+ assert(splits === Array(2.0, 3.0))
}
// find splits when most samples close to the maximum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
+ Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 1)
- assert(splits(0) === 1.0)
+ assert(splits === Array(1.0))
}
+
+ // find splits for constant feature
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(0, 0, 0).map(_.toDouble)
+ val featureSamplesEmpty = Array.empty[Double]
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits === Array[Double]())
+ val splitsEmpty =
+ RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
+ assert(splitsEmpty === Array[Double]())
+ }
+ }
+
+ test("train with constant features") {
+ val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
+ val data = Array.fill(5)(lp)
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(
+ OldAlgo.Classification,
+ Gini,
+ maxDepth = 2,
+ numClasses = 2,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
+ val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+ assert(tree.rootNode.impurity === -1.0)
+ assert(tree.depth === 0)
+ assert(tree.rootNode.prediction === lp.label)
}
test("Multiclass classification with unordered categorical features: split calculations") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org