Hi all,
We've been looking for ways to parallelize our classifier training, and the
n_jobs parameter seemed like a natural candidate. The classifier we're
currently using, SGDClassifier, supports that parameter, but since we're using
a one-vs-rest (OvR) strategy, our estimator is wrapped in OneVsRestClassifier,
e.g. OneVsRestClassifier(SGDClassifier()). Unfortunately, OneVsRestClassifier
does not take an n_jobs parameter; why is this?
Looking at the code, SGDClassifier's fit is implemented like this:

    def _fit_multiclass(self, X, y, sample_weight, n_iter):
        """Fit a multi-class classifier by combining binary classifiers

        Each binary classifier predicts one class versus all others. This
        strategy is called OVA: One Versus All.
        """
        # Use joblib to fit OvA in parallel
        result = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(fit_binary)(self, i, X, y, n_iter,
                                self._expanded_class_weight[i], 1.,
                                sample_weight)
            for i in xrange(len(self.classes_)))

And OneVsRestClassifier's fit is implemented like so:

    def fit_ovr(estimator, X, y):
        """Fit a one-vs-the-rest strategy."""
        _check_estimator(estimator)
        lb = LabelBinarizer()
        Y = lb.fit_transform(y)
        estimators = [_fit_binary(estimator, X, Y[:, i],
                                  classes=["not %s" % str(i), i])
                      for i in range(Y.shape[1])]
        return estimators, lb

It's easy to see how with some slight modifications (wrapping that in a joblib
Parallel() call) we could enable n_jobs for OneVsRestClassifier. This almost
seems too simple, so there must be a good reason why this isn't done; could
you give your opinions on this?
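
To make the idea concrete, here is a rough, standalone sketch of the kind of
change we have in mind. It is not the actual scikit-learn code:
fit_ovr_parallel and _fit_one are names made up for illustration, the n_jobs
argument is our proposed addition, and the sketch skips _check_estimator and
the classes= argument that the real _fit_binary takes. It also assumes every
column of the binarized label matrix contains both classes.

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import clone
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import LabelBinarizer


def _fit_one(estimator, X, y_column):
    # Hypothetical helper: fit a fresh copy of the estimator on one
    # binary "this class vs. the rest" problem.
    return clone(estimator).fit(X, y_column)


def fit_ovr_parallel(estimator, X, y, n_jobs=1):
    """Fit one binary classifier per class, dispatching the fits via joblib."""
    lb = LabelBinarizer()
    Y = lb.fit_transform(y)
    # Same loop as in fit_ovr, but each iteration becomes a joblib task.
    estimators = Parallel(n_jobs=n_jobs)(
        delayed(_fit_one)(estimator, X, Y[:, i])
        for i in range(Y.shape[1]))
    return estimators, lb


# Toy three-class data, only to show the call.
rng = np.random.RandomState(0)
X = rng.randn(60, 4)
y = np.repeat([0, 1, 2], 20)
estimators, lb = fit_ovr_parallel(SGDClassifier(random_state=0), X, y, n_jobs=2)
```

One thing we are unsure about is overhead: with n_jobs > 1 joblib may copy X
to each worker, so for small problems the parallel version could well be
slower than the plain list comprehension.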
Thanks,
Afik Cohen
_______________________________________________
Scikit-learn-general mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/scikit-learn-general