Repository: spark
Updated Branches:
  refs/heads/branch-1.0 885489112 -> d569838bc
[SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)

Hi, this pull fixes (what I believe to be) a bug in DecisionTree.scala.

In the extractLeftRightNodeAggregates function, the first set of rightNodeAgg values for Regression are set in line 792 as follows:

rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * numBins - 1)))

Then there is a loop that sets the rest of the values, as in line 809:

rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
  binData(shift + (2 *(numBins - 2 - splitIndex))) +
  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))

But since splitIndex starts at 1, this ends up skipping a set of binData values.

The changes here address this issue, for both the Regression and Classification cases.

Author: johnnywalleye <[email protected]>

Closes #1316 from johnnywalleye/master and squashes the following commits:

73809da [johnnywalleye] fix bin offset in DecisionTree node aggregations

(cherry picked from commit 1114207cc8e4ef94cb97bbd5a2ef3ae4d51f73fa)
Signed-off-by: Xiangrui Meng <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d569838b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d569838b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d569838b

Branch: refs/heads/branch-1.0
Commit: d569838bc067f2b64f6c10e54ba8e5973f8fc93a
Parents: 8854891
Author: johnnywalleye <[email protected]>
Authored: Tue Jul 8 19:17:26 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Jul 8 19:17:43 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d569838b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3b13e52..74d5d7b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
           // calculating right node aggregate for a split as a sum of right node aggregate of a
           // higher split and the right bin aggregate of a bin where the split is a low split
           rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-            binData(shift + (2 *(numBins - 2 - splitIndex))) +
+            binData(shift + (2 *(numBins - 1 - splitIndex))) +
               rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
           rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-            binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+            binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
               rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)

           splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
           // calculating right node aggregate for a split as a sum of right node aggregate of a
           // higher split and the right bin aggregate of a bin where the split is a low split
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-            binData(shift + (3 * (numBins - 2 - splitIndex))) +
+            binData(shift + (3 * (numBins - 1 - splitIndex))) +
               rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-            binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+            binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
               rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-            binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+            binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
               rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)

           splitIndex += 1
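
For illustration only, the corrected recurrence can be sketched as a standalone suffix sum over the per-bin statistics of a single feature. This is a simplified sketch, not the code in DecisionTree.scala: the object name RightNodeAggSketch, the helper rightAggregates, and the stride parameter (number of statistics stored per bin, 2 or 3 in the hunks above) are hypothetical names introduced here, and the featureIndex dimension is dropped.

```scala
object RightNodeAggSketch {
  // Hypothetical helper (not part of Spark): computes right-node aggregates for one feature.
  // binData holds `stride` statistics per bin, starting at offset `shift`.
  // The result holds `stride` statistics per split; the entry for split s
  // aggregates bins s + 1 .. numBins - 1, i.e. everything to the right of that split.
  def rightAggregates(
      binData: Array[Double],
      numBins: Int,
      stride: Int,
      shift: Int = 0): Array[Double] = {
    val rightNodeAgg = new Array[Double](stride * (numBins - 1))

    // Highest split: its right side consists of the last bin only.
    var k = 0
    while (k < stride) {
      rightNodeAgg(stride * (numBins - 2) + k) = binData(shift + stride * (numBins - 1) + k)
      k += 1
    }

    // Remaining splits, walking from high to low. The bin added at each step is
    // (numBins - 1 - splitIndex), matching the "+" lines of the diff; using
    // (numBins - 2 - splitIndex) here would skip one bin, since splitIndex starts at 1.
    var splitIndex = 1
    while (splitIndex < numBins - 1) {
      k = 0
      while (k < stride) {
        rightNodeAgg(stride * (numBins - 2 - splitIndex) + k) =
          binData(shift + stride * (numBins - 1 - splitIndex) + k) +
            rightNodeAgg(stride * (numBins - 1 - splitIndex) + k)
        k += 1
      }
      splitIndex += 1
    }
    rightNodeAgg
  }
}
```

With one statistic per bin and binData = Array(1.0, 2.0, 3.0, 4.0), the sketch returns Array(9.0, 7.0, 4.0): split 0 has bins 1 to 3 on its right (2 + 3 + 4), split 1 has bins 2 and 3, and split 2 has bin 3 alone. Under the pre-fix offset the same input would give 7.0, 6.0, 4.0, because bin 2 is never added.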
