Posted to commits@spark.apache.org by me...@apache.org on 2014/07/09 04:17:36 UTC
git commit: [SPARK-2152][MLlib] fix bin offset in DecisionTree node
aggregations (also resolves SPARK-2160)
Repository: spark
Updated Branches:
refs/heads/master ac9cdc116 -> 1114207cc
[SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)
Hi, this pull fixes (what I believe to be) a bug in DecisionTree.scala.
In the extractLeftRightNodeAggregates function, the first set of rightNodeAgg values for Classification are set in line 792 as follows:
rightNodeAgg(featureIndex)(2 * (numBins - 2))
= binData(shift + (2 * (numBins - 1)))
Then there is a loop that sets the rest of the values, as in line 809:
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
binData(shift + (2 *(numBins - 2 - splitIndex))) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
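But since splitIndex starts at 1, the first iteration reads the binData values for bin numBins - 3 rather than bin numBins - 2, so the values for bin numBins - 2 are never added to any right node aggregate.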
The changes here address this issue, for both the Regression and Classification cases.
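The off-by-one described above can be reproduced in isolation. The following is a minimal sketch, not the DecisionTree.scala code: it assumes a single feature and one statistic per bin (so the stride of 2 or 3 and the shift term drop out), and it assumes the loop runs while splitIndex < numBins - 1. The names RightAggSketch, rightAgg and binOffset are hypothetical; binOffset = 2 mimics the old indexing and binOffset = 1 the corrected one.

```scala
object RightAggSketch {
  // binData(i) holds one statistic for bin i (hypothetical simplification).
  // With numBins bins there are numBins - 1 splits, and right(s) is meant to be
  // the sum of bins s + 1 .. numBins - 1.
  def rightAgg(binData: Array[Double], binOffset: Int): Array[Double] = {
    val numBins = binData.length
    val right = Array.ofDim[Double](numBins - 1)
    // Highest split: only the last bin lies to its right.
    right(numBins - 2) = binData(numBins - 1)
    var splitIndex = 1
    while (splitIndex < numBins - 1) {
      // binOffset = 2 reproduces the old index, binOffset = 1 the fixed one.
      right(numBins - 2 - splitIndex) =
        binData(numBins - binOffset - splitIndex) + right(numBins - 1 - splitIndex)
      splitIndex += 1
    }
    right
  }

  def main(args: Array[String]): Unit = {
    val binData = Array(1.0, 2.0, 4.0, 8.0)
    println(rightAgg(binData, 2).mkString(", ")) // old offset:   11.0, 10.0, 8.0
    println(rightAgg(binData, 1).mkString(", ")) // fixed offset: 14.0, 12.0, 8.0
  }
}
```

With the old offset, the value for bin 2 (4.0) never reaches any right aggregate, and the lowest split wrongly picks up bin 0 (1.0). With the fixed offset, each entry is the sum of all bins to the right of its split.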
Author: johnnywalleye <js...@gmail.com>
Closes #1316 from johnnywalleye/master and squashes the following commits:
73809da [johnnywalleye] fix bin offset in DecisionTree node aggregations
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1114207c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1114207c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1114207c
Branch: refs/heads/master
Commit: 1114207cc8e4ef94cb97bbd5a2ef3ae4d51f73fa
Parents: ac9cdc1
Author: johnnywalleye <js...@gmail.com>
Authored: Tue Jul 8 19:17:26 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Jul 8 19:17:26 2014 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1114207c/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3b13e52..74d5d7b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
- binData(shift + (2 *(numBins - 2 - splitIndex))) +
+ binData(shift + (2 *(numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
- binData(shift + (3 * (numBins - 2 - splitIndex))) +
+ binData(shift + (3 * (numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
splitIndex += 1