GitHub user sethah opened a pull request:

    https://github.com/apache/spark/pull/16722

    [SPARK-9478][ML][MLlib] Add sample weights to decision trees

    ## What changes were proposed in this pull request?
    
    This patch adds support for sample weights to `DecisionTreeRegressor` and 
`DecisionTreeClassifier`. 
    
    *Note:* This patch does not add support for sample weights to `RandomForest`. As discussed in the JIRA, we would like to incorporate sample weights into the bagging process. This patch is large enough as is, and random forests require some additional consideration. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.
    
    ## How was this patch tested?
    
    The algorithms are tested to ensure that:
    1. Scaling all weights by an arbitrary constant has no effect on the learned model
    2. Outliers with small weights do not affect the learned model
    3. Oversampling (duplicating instances) and weighting them proportionally are equivalent
    
    Unit tests are also added to test other smaller components.
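
    As a rough sketch, the first property can be checked along these lines (this assumes the `weightCol` param this patch wires into the tree estimators; `data` is a hypothetical training DataFrame):

    ````scala
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.sql.functions.lit

    // Train once with unit weights and once with every weight scaled by the
    // same constant; the learned trees should be structurally identical.
    val dt = new DecisionTreeClassifier().setWeightCol("weight")
    val model1 = dt.fit(data.withColumn("weight", lit(1.0)))
    val model2 = dt.fit(data.withColumn("weight", lit(42.0)))
    assert(model1.toDebugString == model2.toDebugString)
    ````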
    
    ## Summary of changes
    
    * Impurity aggregators now store weighted sufficient statistics. They also store a raw, unweighted count, since that is needed to enforce `minInstancesPerNode`.
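
    For example, weighted Gini impurity replaces raw class counts with sums of sample weights (a minimal sketch, not the aggregator's actual code):

    ````scala
    // Weighted Gini impurity: each class "count" is the sum of the sample
    // weights of the instances in that class.
    def weightedGini(weightedClassCounts: Array[Double]): Double = {
      val totalWeight = weightedClassCounts.sum
      if (totalWeight == 0.0) 0.0
      else 1.0 - weightedClassCounts.map { w =>
        val p = w / totalWeight
        p * p
      }.sum
    }
    ````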
    
    * This patch preserves the meaning of `minInstancesPerNode`: the parameter still corresponds to raw, unweighted counts. It also adds a new parameter `minWeightFractionPerNode`, which requires that each node contain at least `minWeightFractionPerNode * weightedNumExamples` total weight.
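
    The interaction between the two minimums might be sketched like this (a hypothetical helper; the names are illustrative, not the patch's actual internals):

    ````scala
    // A node is a valid candidate only if it passes both the raw-count gate
    // and the weighted-fraction gate.
    def isNodeValid(
        rawCount: Long,
        nodeWeight: Double,
        weightedNumExamples: Double,
        minInstancesPerNode: Int,
        minWeightFractionPerNode: Double): Boolean = {
      // Pre-existing semantics: raw, unweighted counts.
      rawCount >= minInstancesPerNode &&
        // New semantics: a minimum share of the total sample weight.
        nodeWeight >= minWeightFractionPerNode * weightedNumExamples
    }
    ````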
    
    * This patch modifies `findSplitsForContinuousFeatures` to use weighted 
sums. Unit tests are added.
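
    For intuition, a much-simplified version of weighted split selection for a single continuous feature might look like the following (illustrative only; the real implementation works on sampled, distributed data):

    ````scala
    import scala.collection.mutable.ArrayBuffer

    // Walk the sorted (value, weight) pairs and emit a candidate threshold
    // each time the cumulative weight crosses an equal-weight stride.
    def weightedSplits(
        valueWeights: Seq[(Double, Double)],
        numSplits: Int): Seq[Double] = {
      val sorted = valueWeights.sortBy(_._1).toIndexedSeq
      val stride = sorted.map(_._2).sum / (numSplits + 1)
      val splits = ArrayBuffer.empty[Double]
      var cumWeight = 0.0
      var target = stride
      for (i <- 0 until sorted.length - 1) {
        cumWeight += sorted(i)._2
        if (cumWeight >= target) {
          // Place the threshold midway between adjacent values.
          splits += (sorted(i)._1 + sorted(i + 1)._1) / 2.0
          target += stride
        }
      }
      splits.distinct.toSeq
    }
    ````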
    
    * `TreePoint` is modified to hold a sample weight.
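
    The shape of that change is roughly the following (field names are approximate):

    ````scala
    // Binned representation of one training instance, now carrying its
    // sample weight alongside the label and binned feature values.
    private[spark] class TreePoint(
        val label: Double,
        val binnedFeatures: Array[Int],
        val weight: Double) extends Serializable
    ````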
    
    * `BaggedPoint` is modified from:

    ````scala
    private[spark] class BaggedPoint[Datum](
        val datum: Datum,
        val subsampleWeights: Array[Double]) extends Serializable
    ````
    
    to 
    
    ````scala
    private[spark] class BaggedPoint[Datum](
        val datum: Datum,
        val subsampleCounts: Array[Int],
        val sampleWeight: Double) extends Serializable
    ````
    
    We do not simply multiply the counts by the weight and store that product, because we need both the raw counts and the weight in order to enforce both `minInstancesPerNode` and `minWeightFractionPerNode`.
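
    Concretely, the effective weight a datum contributes to tree `i` can be recovered on demand (an illustrative helper, not code from the patch):

    ````scala
    // Keep the raw count available for minInstancesPerNode checks; derive
    // the weighted contribution only where impurity statistics need it.
    def effectiveWeight(point: BaggedPoint[_], treeIndex: Int): Double =
      point.subsampleCounts(treeIndex) * point.sampleWeight
    ````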
    
    *Note:* many of the changed files differ only because they use `Instance` instead of `LabeledPoint`.
    


You can merge this pull request into a Git repository by running:

    $ git pull https://github.com/sethah/spark SPARK-9478-tree

Alternatively you can review and apply these changes as the patch at:

    https://github.com/apache/spark/pull/16722.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

    This closes #16722
    
----
commit 2d86cea640634a205e378bddee0b01780d019ea2
Author: sethah <[email protected]>
Date:   2017-01-27T16:38:36Z

    add weights to dt

commit 7dc1437df21999554e42d35d1d544839074414cf
Author: sethah <[email protected]>
Date:   2017-01-27T20:34:24Z

    dt tests passing

----

