GitHub user sethah opened a pull request:
https://github.com/apache/spark/pull/16722
[SPARK-9478][ML][MLlib] Add sample weights to decision trees
## What changes were proposed in this pull request?
This patch adds support for sample weights to `DecisionTreeRegressor` and
`DecisionTreeClassifier`.
*Note:* This patch does not add support for sample weights to RandomForest.
As discussed in the JIRA, we would like to add sample weights into the bagging
process. This patch is large enough as is, and there are some additional
considerations to be made for random forests. Since the machinery introduced
here needs to be present regardless, I have opted to leave random forests for a
follow-up PR.
## How was this patch tested?
The algorithms are tested to ensure that:
1. Arbitrary scaling of constant weights has no effect
2. Outliers with small weights do not affect the learned model
3. Oversampling and weighting are equivalent
Unit tests are also added to test other smaller components.
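To make properties 1 and 3 concrete, here is a minimal, self-contained sketch (not Spark's actual test code; `weightedGini` is a hypothetical helper) of how they can be checked against a weighted impurity measure:

````scala
// Hypothetical weighted Gini impurity, used only to illustrate the
// invariance properties the tests verify.
def weightedGini(labels: Seq[Int], weights: Seq[Double]): Double = {
  val total = weights.sum
  val byClass = labels.zip(weights).groupBy(_._1).values.map(_.map(_._2).sum)
  1.0 - byClass.map(w => math.pow(w / total, 2)).sum
}

val labels = Seq(0, 0, 1, 1, 1)
val unit   = Seq.fill(5)(1.0)

// Property 1: scaling all weights by a constant changes nothing.
val scaled = unit.map(_ * 7.3)

// Property 3: duplicating a point is equivalent to giving it weight 2.0.
val overLabels = labels :+ 1          // the last point appears twice
val overUnit   = unit :+ 1.0
val wtdWeights = Seq(1.0, 1.0, 1.0, 1.0, 2.0)
````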
## Summary of changes
* Impurity aggregators now store weighted sufficient statistics. They also
store the raw (unweighted) count, since that is needed to enforce
`minInstancesPerNode`.
* This patch preserves the meaning of `minInstancesPerNode`: the parameter
still applies to raw, unweighted counts. It also adds a new parameter,
`minWeightFractionPerNode`, which requires that each node contain at least
`minWeightFractionPerNode * weightedNumExamples` total weight.
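As a sketch of how the two parameters interact (names like `NodeStats` and `isSplitValid` are illustrative, not Spark's internal API), a split-validity check combining them might look like:

````scala
// Illustrative sketch only: a candidate split is valid when both children
// satisfy the raw-count threshold and the weighted threshold.
case class NodeStats(rawCount: Long, weightSum: Double)

def isSplitValid(
    left: NodeStats,
    right: NodeStats,
    minInstancesPerNode: Int,
    minWeightFractionPerNode: Double,
    weightedNumExamples: Double): Boolean = {
  val minWeight = minWeightFractionPerNode * weightedNumExamples
  // minInstancesPerNode still gates on raw, unweighted counts ...
  left.rawCount >= minInstancesPerNode && right.rawCount >= minInstancesPerNode &&
  // ... while the new parameter gates on total weight in each child.
  left.weightSum >= minWeight && right.weightSum >= minWeight
}
````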
* This patch modifies `findSplitsForContinuousFeatures` to use weighted
sums. Unit tests are added.
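As an illustration of what weighted split finding means (a simplified sketch, not the actual `findSplitsForContinuousFeatures` logic), candidate thresholds for a continuous feature can be placed at weighted-quantile boundaries instead of count-based ones:

````scala
// Illustrative sketch: choose split thresholds so each bin holds roughly
// equal total weight. `values` holds (featureValue, weight) pairs.
def weightedSplits(values: Seq[(Double, Double)], numSplits: Int): Seq[Double] = {
  val sorted = values.sortBy(_._1)
  val totalWeight = sorted.map(_._2).sum
  val step = totalWeight / (numSplits + 1)
  var target = step
  var cumWeight = 0.0
  val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
  for ((v, w) <- sorted if splits.size < numSplits) {
    cumWeight += w
    if (cumWeight >= target) {
      splits += v
      target += step
    }
  }
  splits.toSeq
}
````

Note that scaling every weight by the same constant leaves the chosen thresholds unchanged, which is exactly test property 1 above.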
* `TreePoint` is modified to hold a sample weight.
* BaggedPoint is modified from:
````scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val
subsampleWeights: Array[Double]) extends Serializable
````
to
````scala
private[spark] class BaggedPoint[Datum](
val datum: Datum,
val subsampleCounts: Array[Int],
val sampleWeight: Double) extends Serializable
````
We do not simply multiply the counts by the weight and store the product,
because the raw counts and the weight are each needed separately to enforce
`minInstancesPerNode` and `minWeightFractionPerNode`, respectively.
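To illustrate why the two fields are kept separate, here is a sketch of how they combine during aggregation (the accessor methods are hypothetical, added only for illustration):

````scala
// The new BaggedPoint shape from above, with illustrative accessors.
class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable {
  // Raw count for tree `treeIndex`, as consumed by minInstancesPerNode.
  def count(treeIndex: Int): Int = subsampleCounts(treeIndex)
  // Weighted contribution, as consumed by the impurity statistics and
  // minWeightFractionPerNode.
  def weightedCount(treeIndex: Int): Double =
    subsampleCounts(treeIndex) * sampleWeight
}

val p = new BaggedPoint("x", Array(2, 0, 1), sampleWeight = 0.5)
// p.count(0) is 2 while p.weightedCount(0) is 1.0: storing only the
// product would make the raw count unrecoverable.
````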
*Note:* many of the changed files are modified only to use `Instance`
instead of `LabeledPoint`.
You can merge this pull request into a Git repository by running:
$ git pull https://github.com/sethah/spark SPARK-9478-tree
Alternatively you can review and apply these changes as the patch at:
https://github.com/apache/spark/pull/16722.patch
To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:
This closes #16722
----
commit 2d86cea640634a205e378bddee0b01780d019ea2
Author: sethah <[email protected]>
Date: 2017-01-27T16:38:36Z
add weights to dt
commit 7dc1437df21999554e42d35d1d544839074414cf
Author: sethah <[email protected]>
Date: 2017-01-27T20:34:24Z
dt tests passing
----