Hi Conrad,
2012/7/30 Conrad Lee <[email protected]>:
> Two of the most important parameters of the gradient boosting classifier are
> the learn_rate and n_estimators. In order to set these, the documentation
> states:
>
>> [HTF2009] recommend to set the learning rate to a small constant (e.g.
>> learn_rate <= 0.1) and choose n_estimators by early stopping
>
>
> Although the documentation for this classifier is in general good (thanks!),
> I didn't see how to perform this early stopping. The examples do make it
> clear how I can first fit the classifier for a large n_estimators
> value, and subsequently look back and see if I would have gotten better
> results with fewer trees. However, that's rather inefficient - to be
> efficient, I'd like to stop fitting additional trees as soon as the accuracy
> stops improving substantially.
You're correct - currently this is the only way to perform "early
stopping". The original pull request supported an additional parameter,
``monitor``, which was called after each iteration with the current
state of the model and allowed proper early stopping; however, we
removed the parameter because we could not agree on the API.
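To make the ``monitor`` idea concrete, here is a minimal sketch of callback-driven early stopping, written in plain Python rather than against scikit-learn. ``fit_with_monitor``, ``fit_one_stage`` and ``make_patience_monitor`` are hypothetical names, not scikit-learn API; ``fit_one_stage(i)`` is assumed to add one stage to the model and return its held-out loss.

```python
from typing import Callable, List


def fit_with_monitor(fit_one_stage: Callable[[int], float],
                     max_stages: int,
                     monitor: Callable[[int, List[float]], bool]) -> List[float]:
    """Fit up to max_stages stages, asking monitor after each one whether to stop.

    fit_one_stage(i) is a hypothetical hook: it adds stage i to the model and
    returns the held-out loss of the model after that stage.
    """
    history: List[float] = []
    for i in range(max_stages):
        history.append(fit_one_stage(i))
        # the monitor sees the stage index and the loss history so far
        if monitor(i, history):
            break
    return history


def make_patience_monitor(patience: int,
                          tol: float = 0.0) -> Callable[[int, List[float]], bool]:
    """Stop once the best loss of the last `patience` stages is not at least
    `tol` better than the best loss seen before them."""
    def monitor(i: int, history: List[float]) -> bool:
        if len(history) <= patience:
            return False
        best_before = min(history[:-patience])
        recent_best = min(history[-patience:])
        return recent_best > best_before - tol
    return monitor
```

Conrad's criterion below (stop when the loss improves by less than 0.01) corresponds roughly to ``make_patience_monitor(patience=1, tol=0.01)``; a larger ``patience`` tolerates short plateaus in the held-out loss.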
>
> I suppose it's something like the following:
>
> from sklearn import ensemble
>
> X_train, X_test = X[:2000], X[2000:]
> y_train, y_test = y[:2000], y[2000:]
>
> n_additional_trees = [10, 90, 900, 9000, 90000]
> clf = ensemble.GradientBoostingClassifier(learn_rate=0.005,
>                                           n_estimators=n_additional_trees.pop(0),
>                                           subsample=0.5)
> clf.fit(X_train, y_train)
>
> previous_error = 1.0
> current_error = clf.loss_(y_test, y_pred)
> while (previous_error - current_error) > 0.01:
>     previous_error = current_error
>     for additional_tree in range(n_additional_trees.pop(0)):
>         clf.fit_stage(UNDOCUMENTED_MYSTERY_PARAMS)
>     current_error = clf.loss_(y_test, y_pred)
>
> What I want the above code to do is run the boosting classifier first with
> 10 trees, then with 100, 1000, and 10000 trees. It will stop at any of
> these breakpoints if the improvement in accuracy is less than 0.01. Is this
> code basically correct - i.e., is this what is meant by "early stopping"?
The problem with the above code is that ``fit`` will clear the model
(i.e. ``clf.estimators_ = []``) and fit from scratch - this is rather
inefficient because it will perform a lot of redundant computations.
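The redundancy can be counted in boosting stages. The sketch below assumes every stage costs the same and compares refitting from scratch at each checkpoint with only adding the missing stages; ``stages_fitted`` is a hypothetical helper, and the checkpoints 10, 100, 1000, 10000 are the cumulative sums of Conrad's schedule.

```python
def stages_fitted(checkpoints, warm_start):
    """Total number of boosting stages fitted to evaluate every checkpoint.

    checkpoints: strictly increasing cumulative stage counts, e.g. [10, 100, 1000].
    warm_start=False: each checkpoint refits from scratch (like calling fit again).
    warm_start=True: each checkpoint only adds the stages that are still missing.
    """
    if any(b <= a for a, b in zip(checkpoints, checkpoints[1:])):
        raise ValueError("checkpoints must be strictly increasing")
    if not checkpoints:
        return 0
    return checkpoints[-1] if warm_start else sum(checkpoints)
```

With Conrad's geometric schedule, refitting costs 11110 stages instead of 10000; checking every 10 stages up to 1000 costs 50500 stages instead of 1000, so the waste grows quickly as the checks get finer.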
> One problem is that the parameters for the fit_stage method are not
> documented (here I just used the placeholder UNDOCUMENTED_MYSTERY_PARAMS).
``fit_stage`` is an internal method and is not intended for public use.
I'll mark it as private (e.g. with a leading underscore) to make this explicit.
> I'll have a closer look at the source code to try to figure out what belongs
> here, but ideally this method would have better documentation.
To sum up:
Currently, the recommended way to determine the "optimal" number of
estimators is to fit a large number of trees (e.g. 10000) and then
use held-out data along with ``predict_stages`` to select the "optimal"
number of estimators (= stages = iterations)::

    import numpy as np

    clf.fit(X_train, y_train)
    # one validation score per stage (not per validation sample)
    val_scores = np.empty((clf.n_estimators,), dtype=np.float64)
    # X_val, y_val: held-out data; score: any higher-is-better metric
    for i, y_pred in enumerate(clf.predict_stages(X_val)):
        val_scores[i] = score(y_val, y_pred)
    best_iter = np.argmax(val_scores) + 1
    best_score = np.max(val_scores)
    # truncate the model to the optimal number of estimators
    clf.estimators_ = clf.estimators_[:best_iter, :]
    clf.n_estimators = best_iter

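The selection step above can also be written without NumPy, which makes the tie-breaking explicit. ``select_n_stages`` and ``accuracy`` are hypothetical helpers; the per-stage predictions are assumed to come from a generator like the ``predict_stages`` call in the snippet.

```python
from typing import Callable, Iterable, Sequence, Tuple


def select_n_stages(staged_preds: Iterable[Sequence],
                    y_val: Sequence,
                    score: Callable[[Sequence, Sequence], float]) -> Tuple[int, float]:
    """Return (best_iter, best_score) over per-stage validation predictions.

    staged_preds yields the predictions after stage 1, 2, ...; score is
    higher-is-better. Ties keep the earlier (smaller) model, like np.argmax,
    which returns the first maximum.
    """
    best_iter, best_score = 0, float("-inf")
    for i, y_pred in enumerate(staged_preds, start=1):
        s = score(y_val, y_pred)
        if best_iter == 0 or s > best_score:
            best_iter, best_score = i, s
    if best_iter == 0:
        raise ValueError("staged_preds was empty")
    return best_iter, best_score


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of matching labels (a simple higher-is-better score)."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
```

The returned ``best_iter`` is then used exactly as in the snippet, keeping only the first ``best_iter`` stages of ``estimators_``.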
In the near future, I intend to add a ``warm_start`` parameter to
``BaseGradientBoosting``, which will allow re-fitting an already fitted
model without starting from scratch (similar to GBM's ``more``), and a
convenience method to determine the "optimal" number of estimators based
on held-out data, CV, or OOB estimates (similar to GBM's ``optimal``).
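The planned ``warm_start`` behaviour can be illustrated with a toy, pure-Python least-squares booster on a single feature. ``ToyBooster`` and ``fit_stump`` are hypothetical and not the scikit-learn implementation; the point is only that with ``warm_start=True`` a second ``fit`` call keeps the existing stages and fits just the missing ones.

```python
from typing import List, Sequence, Tuple

# (threshold, value for x <= threshold, value for x > threshold)
Stump = Tuple[float, float, float]


def fit_stump(x: Sequence[float], r: Sequence[float]) -> Stump:
    """Least-squares decision stump on a single feature."""
    xs = sorted(set(x))
    # candidate thresholds halfway between consecutive distinct values
    cands = [(a + b) / 2 for a, b in zip(xs, xs[1:])] or [xs[0]]
    best = None
    for t in cands:
        left = [ri for xi, ri in zip(x, r) if xi <= t]
        right = [ri for xi, ri in zip(x, r) if xi > t]
        lv = sum(left) / len(left) if left else 0.0
        rv = sum(right) / len(right) if right else 0.0
        sse = sum((ri - lv) ** 2 for ri in left) + sum((ri - rv) ** 2 for ri in right)
        if best is None or sse < best[0]:
            best = (sse, t, lv, rv)
    return best[1], best[2], best[3]


class ToyBooster:
    """Hypothetical least-squares booster with a warm_start flag."""

    def __init__(self, n_estimators: int = 10, learning_rate: float = 0.1,
                 warm_start: bool = False):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.warm_start = warm_start
        self.estimators_: List[Stump] = []
        self.init_ = 0.0

    def _stage_value(self, stump: Stump, xi: float) -> float:
        t, lv, rv = stump
        return self.learning_rate * (lv if xi <= t else rv)

    def predict(self, x: Sequence[float]) -> List[float]:
        out = [self.init_] * len(x)
        for stump in self.estimators_:
            out = [o + self._stage_value(stump, xi) for o, xi in zip(out, x)]
        return out

    def fit(self, x: Sequence[float], y: Sequence[float]) -> "ToyBooster":
        if not self.warm_start or not self.estimators_:
            # cold start: discard old stages, like calling fit from scratch
            self.estimators_ = []
            self.init_ = sum(y) / len(y)
        elif self.n_estimators < len(self.estimators_):
            raise ValueError("n_estimators must not shrink under warm_start")
        pred = self.predict(x)
        # fit only the stages that are still missing
        for _ in range(len(self.estimators_), self.n_estimators):
            # negative gradient of 0.5 * (y - f)^2 is the residual y - f
            resid = [yi - pi for yi, pi in zip(y, pred)]
            stump = fit_stump(x, resid)
            self.estimators_.append(stump)
            pred = [p + self._stage_value(stump, xi) for p, xi in zip(pred, x)]
        return self
```

Raising ``n_estimators`` in steps and checking the held-out loss between ``fit`` calls would give the incremental loop Conrad asked for, without refitting the earlier stages.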
best,
Peter
--
Peter Prettenhofer
_______________________________________________
Scikit-learn-general mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/scikit-learn-general