imatiach-msft commented on a change in pull request #21632:
[SPARK-19591][ML][MLlib] Add sample weights to decision trees
URL: https://github.com/apache/spark/pull/21632#discussion_r247175614
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
##########
@@ -1002,19 +1019,20 @@ private[spark] object RandomForest extends Logging with Serializable {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value except zero value
- val partNumSamples = featureSamples.size
- val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
- featureSamples.foreach { x =>
- partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
- }
+ val (partValueCountMap, partNumSamples) =
+ featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
+ case ((m, cnt), (w, x)) =>
+ (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
+ }
// Calculate the expected number of samples for finding splits
-    val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
+    val weightedNumSamples = samplesFractionForFindSplits(metadata) * metadata.weightedNumExamples
// add expected zero value count and get complete statistics
-    val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
-      partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
+    val valueCountMap: Map[Double, Double] = if (weightedNumSamples - partNumSamples > 1e-5) {
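As a side note on the hunk above, the weighted-count fold it introduces can be sketched standalone like this (hypothetical names; each sample is a `(weight, value)` pair, and the fold builds a value-to-total-weight map plus the total weight in one pass):

```scala
object WeightedCount {
  // Accumulate (value -> summed weight) and the overall weight sum,
  // mirroring the foldLeft in the patch.
  def countValues(samples: Seq[(Double, Double)]): (Map[Double, Double], Double) =
    samples.foldLeft((Map.empty[Double, Double], 0.0)) {
      case ((m, cnt), (w, x)) =>
        (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
    }
}
```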
Review comment:
OK, so the tests all pass when I do this, but if I put a print here like this:
```
val valueCountMap = if (weightedNumSamples - partNumSamples > Utils.EPSILON) {
  println("adding zero weight: " + (weightedNumSamples - partNumSamples))
  partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
} else {
  partValueCountMap
}
```
and run the sample weights test:
> testOnly org.apache.spark.ml.regression.DecisionTreeRegressorSuite -- -z "sample weights"
I get the output:
```
adding zero weight: 4.440892098500626E-14
adding zero weight: 4.440892098500626E-14
adding zero weight: 1.432454155292362E-11
adding zero weight: 1.432454155292362E-11
adding zero weight: 4.440892098500626E-14
adding zero weight: 4.440892098500626E-14
adding zero weight: 1.432454155292362E-11
adding zero weight: 1.432454155292362E-11
```
We really should be ignoring those, and as you can see, for most of them the weight is on the order of 1e-14 to 1e-11.
The problem is that partNumSamples is the sum of a lot of doubles, and since we are adding so many of them together the precision isn't very good, due to the limits of floating-point representation. For example, if you print out weightedNumSamples and partNumSamples you will see something similar to this for the 1000-weight case:
```
weightedNumSamples: 1000.1999999999784
partNumSamples:     1000.1999999999641
```
Utils.EPSILON would usually be fine, but since we are summing so many values it seems to be too strict.
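The accumulation effect described above can be reproduced in isolation (a standalone sketch, not from the patch):

```scala
// 0.1 has no exact binary representation, so ten thousand additions
// accumulate rounding error many orders of magnitude larger than
// machine epsilon (~2.2e-16), even though each single add is tiny.
val naiveSum = (1 to 10000).foldLeft(0.0)((acc, _) => acc + 0.1)
val error = math.abs(naiveSum - 1000.0)
// error is small in relative terms, but well above machine epsilon
```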
I think I've come up with a good formula though:
```
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
```
This seems to work well enough in practice, and it customizes the tolerance based on the number of examples summed.
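To see how that tolerance behaves numerically, here is a hedged standalone sketch (Utils.EPSILON is stood in for by machine epsilon, which is an assumption, and `tolerance`/`unweightedNumSamples` are hypothetical names):

```scala
// The allowed slack grows quadratically with the number of samples summed.
def tolerance(epsilon: Double, unweightedNumSamples: Long): Double =
  epsilon * unweightedNumSamples * unweightedNumSamples

val eps = math.ulp(1.0)          // machine epsilon stand-in, ~2.22e-16
val tol = tolerance(eps, 1000L)  // ~2.2e-10 for 1000 samples, comfortably
                                 // above the ~1e-11 residuals in the output
```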
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]