GitHub user manishamde commented on a diff in the pull request:
https://github.com/apache/spark/pull/1720#discussion_r15722046
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Sequential search helper method to find bin for categorical feature.
+ * Sequential search helper method to find bin for categorical feature
+ * (for classification and regression).
*/
- def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+ def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureValue = labeledPoint.features(featureIndex)
var binIndex = 0
- while (binIndex < numCategoricalBins) {
+ while (binIndex < featureCategories) {
val bin = bins(featureIndex)(binIndex)
val categories = bin.highSplit.categories
- val features = labeledPoint.features
- if (categories.contains(features(featureIndex))) {
+ if (categories.contains(featureValue)) {
return binIndex
}
binIndex += 1
}
+ if (featureValue < 0 || featureValue >= featureCategories) {
--- End diff --
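For readers without the rest of DecisionTree.scala, here is a minimal, self-contained Scala sketch of what the revised helper appears to do: a linear scan over one bin per category, returning the first bin whose categories contain the feature value. `Split`, `Bin`, and `findOrderedCategoricalBin` are hypothetical stand-ins, not MLlib's actual classes. The diff is cut off inside the range check, so the sketch only assumes that an unmatched value is reported as an error; that part is a guess. Note the loop bound change visible in the diff: for a feature with k categories the old loop scanned 2^(k-1) - 1 bins (7 when k = 4), while the new one scans k bins (4).

```scala
// Hypothetical simplified types; MLlib's real Bin/Split classes carry more fields.
final case class Split(categories: List[Double])
final case class Bin(highSplit: Split)

object OrderedCategoricalSearch {
  /** Returns the index of the bin whose categories contain featureValue.
    * Assumes one bin per category (bins.length >= featureCategories). */
  def findOrderedCategoricalBin(
      bins: Array[Bin],
      featureCategories: Int,
      featureValue: Double): Int = {
    var binIndex = 0
    // Linear scan over k bins, matching the revised loop bound in the diff.
    while (binIndex < featureCategories) {
      if (bins(binIndex).highSplit.categories.contains(featureValue)) {
        return binIndex
      }
      binIndex += 1
    }
    // Assumption: the truncated range check rejects out-of-range values.
    if (featureValue < 0 || featureValue >= featureCategories) {
      throw new IllegalArgumentException(
        s"Categorical feature value $featureValue is outside [0, $featureCategories)")
    }
    // Assumption: an in-range value with no matching bin is also an error here.
    throw new IllegalStateException(s"No bin found for categorical value $featureValue")
  }
}
```

Scanning linearly is cheap here because k is usually small; for features with many categories a direct index lookup would avoid the scan.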
I like this idea. Do you think it might be better handled in a verification
step before training? The idea would be to run the tree in a "diagnostic" mode
so that such errors are caught before training starts.
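One way to read this suggestion is as a separate pass over the training data, run before any bins are searched, that rejects categorical values outside [0, k). The sketch below is a hypothetical illustration only: `validateCategoricalFeatures` is not an MLlib method, and rows are plain `Array[Double]` rather than `LabeledPoint` so it compiles without Spark. The map argument is assumed to mirror `strategy.categoricalFeaturesInfo` from the diff (feature index to number of categories).

```scala
object CategoricalFeatureCheck {
  /** Hypothetical "diagnostic" pass: returns one message per out-of-range value.
    * categoricalFeaturesInfo maps feature index -> number of categories (k). */
  def validateCategoricalFeatures(
      rows: Seq[Array[Double]],
      categoricalFeaturesInfo: Map[Int, Int]): Seq[String] = {
    for {
      (row, rowIndex) <- rows.zipWithIndex
      // Sort by feature index so messages come out in a stable order.
      (featureIndex, numCategories) <- categoricalFeaturesInfo.toSeq.sortBy(_._1)
      value = row(featureIndex)
      if value < 0 || value >= numCategories
    } yield s"row $rowIndex, feature $featureIndex: value $value not in [0, $numCategories)"
  }
}
```

Collecting every bad value instead of failing on the first one suits a diagnostic mode: users can fix all problems in their data in a single round trip.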