Hello,
Version 0.16.1 adds warm_start to RandomForestClassifier, but the documentation
doesn't include a note that warm_start disables parallelization. I found a
reference to this in a comment in the "OOB Errors for Random Forests" example
in the development documentation.
http://scikit-learn.org/dev/auto_examples/ensemble/plot_ensemble_oob.html
Setting both warm_start and n_jobs does not generate a warning or error. My own
testing indicates that n_jobs is still honored when warm_start=True in version
0.16.1: Task Manager shows more than one processor in use during fit.
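A much smaller check along the same lines, as a sketch only: it does not measure CPU use, so it cannot confirm parallelism by itself. It only shows that fitting with both warm_start=True and n_jobs set runs, adds trees incrementally, and records any warnings raised along the way. The dataset size, the tree counts, and the names X_small, y_small and check_forest are made up for illustration.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small synthetic problem so the check runs in a few seconds (illustrative sizes).
X_small, y_small = make_classification(n_samples=500, n_features=20,
                                       n_informative=5, random_state=0)

# Both options set together, as in the question.
check_forest = RandomForestClassifier(n_estimators=10, warm_start=True,
                                      n_jobs=2, random_state=0)

tree_counts = []
first_tree = None
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    for n_trees in (10, 20, 30):
        check_forest.set_params(n_estimators=n_trees)
        check_forest.fit(X_small, y_small)
        tree_counts.append(len(check_forest.estimators_))
        if first_tree is None:
            # Remember the first fitted tree to see whether later fits keep it.
            first_tree = check_forest.estimators_[0]
    caught_messages = [str(w.message) for w in caught]

# With warm_start, each fit should only add the missing trees.
first_tree_kept = check_forest.estimators_[0] is first_tree

print("Trees after each fit:", tree_counts)
print("First tree reused:", first_tree_kept)
print("Warnings raised:", caught_messages)
```

Whether the second core is actually used still has to be confirmed externally, for example in Task Manager as described above.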
I am using an experimental 64-bit NumPy build with Mingw-w64 and OpenBLAS
provided by Carl Kleffner
(https://bitbucket.org/carlkl/mingw-w64-for-python/downloads). I also have a VM
with the out-of-the-box numpy and scikit-learn version 0.16.1, and I observe
the same behavior there: more than one core in use, as confirmed by Task
Manager. For the record, I'm using Windows 7 and Anaconda 3.
Am I missing something? Does warm_start allow the use of more than one
processor? Has there been a change in the development tree that affects
parallelization? I’ve searched around for an answer but can’t find anything
relevant. Code that reproduces this is included below.
The RandomForestClassifier constructor documentation doesn’t address these
concerns. I’m willing to edit the documentation myself once this issue is
clarified.
Thanks.

import numpy as np
import seaborn as sns
import sklearn as sk
from pandas import DataFrame
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

print(sk.__version__)

# Build a binary classification task with 100,000 samples and 500 features,
# 30 of which are informative.
X, y = make_classification(n_samples=100000,
                           n_features=500,
                           n_informative=30,
                           n_redundant=0,
                           n_repeated=0,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

forest = RandomForestClassifier(n_jobs=10, random_state=100, oob_score=True,
                                bootstrap=True, warm_start=True)

n_estimators = 200
rng = range(50, n_estimators + 1, 25)
error_rate = DataFrame(index=np.arange(0, len(rng)),
                       columns=('Number of Trees', 'OOB Error'))

# Grow the forest in steps of 25 trees; warm_start keeps the trees already fit.
for i, n_trees in enumerate(rng):
    print("Fit training set for {0:d} trees.".format(n_trees))
    forest.set_params(n_estimators=n_trees)
    forest.set_params(n_jobs=10)  # set again before every fit
    params = forest.get_params()  # available for inspecting the settings in effect
    forest.fit(X, y)
    error_rate.loc[i] = [n_trees, 1 - forest.oob_score_]

# The frame was created empty, so its columns are object dtype; convert to
# float before plotting.
error_rate = error_rate.astype(float)
sns.lmplot(x='Number of Trees', y='OOB Error',
           data=error_rate).savefig("test_warm_start.png")
print("Finished")

Dale Smith, Ph.D.
Data Scientist
d. 404.495.7220 x 4008 f. 404.795.7221
Nexidia Corporate | 3565 Piedmont Road, Building Two, Suite 400 | Atlanta, GA
30305
_______________________________________________
Scikit-learn-general mailing list
Scikit-learn-general@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/scikit-learn-general