Nowhere in your MyClassifier class do you call
base_classifiers.fit(...). GridSearchCV works on a clone of your
estimator, and the cloned base_classifiers is unfitted, so calling
base_classifiers.predict(...) raises an error telling you that it was
not fitted.
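
To illustrate, here is a minimal sketch, not DESlib code. It shows that
clone returns an unfitted copy, and that fitting the base classifier
inside fit avoids the error. RefittingClassifier is a hypothetical name,
and the sketch assumes it is acceptable to refit the base classifier on
the data passed to fit, which differs from the pre-fitted pool design
described in the original message.

```python
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


class RefittingClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical variant of MyClassifier that fits its own base classifier."""

    def __init__(self, base_classifiers=None, k=1):
        self.base_classifiers = base_classifiers
        self.k = k  # placeholder hyperparameter, unused in this sketch

    def fit(self, X, y):
        # Fit a fresh clone so the constructor argument is left untouched;
        # the trailing underscore marks state learned during fit.
        self.base_classifiers_ = clone(self.base_classifiers).fit(X, y)
        self.classes_ = self.base_classifiers_.classes_
        return self

    def predict(self, X):
        return self.base_classifiers_.predict(X)


X, y = load_iris(return_X_y=True)
X_train, X_dsel, y_train, y_dsel = train_test_split(
    X, y, test_size=0.5, random_state=0
)

fitted = BaggingClassifier(random_state=0).fit(X_train, y_train)
cloned = clone(fitted)
print(hasattr(fitted, "estimators_"))  # fitted state of the original
print(hasattr(cloned, "estimators_"))  # clone copies constructor parameters only

grid = GridSearchCV(RefittingClassifier(fitted), {"k": [1, 3, 5, 7]})
grid.fit(X_dsel, y_dsel)  # each cloned candidate refits its base classifier
print(grid.best_params_)
```

Note that this refits the base classifier on every CV training fold, so
it does not preserve a pool that was trained once on separate data.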

On Wed, 19 Sep 2018 at 16:43, Luiz Gustavo Hafemann <luiz...@gmail.com> wrote:
>
> Hello,
>
> I am one of the developers of a library for Dynamic Ensemble Selection (DES) 
> methods (the library is called DESlib), and we are currently working to get 
> the library fully compatible with scikit-learn (to submit it to 
> scikit-learn-contrib). We have "check_estimator" working for most of the
> classes, but I am now having trouble making the classes compatible with
> GridSearchCV and other CV functions.
>
> One of the main use cases of this library is to facilitate research in this
> field, and this led to a design decision that the base classifiers are fit by
> the user, and the DES methods receive a pool of base classifiers that were
> already fit (this allows users to compare many DES techniques with the same
> base classifiers). This is creating an issue with GridSearch, since the clone 
> method (defined in sklearn.base) is not cloning the classes as we would like. 
> It does a shallow (non-deep) copy of the parameters, but we would like the 
> pool of base classifiers to be deep-copied.
>
> I analyzed this issue and I could not find a solution that does not require
> changes to the scikit-learn code. Here is the sequence of steps that causes
> the problem:
>
> 1. GridSearchCV calls "clone" on the DES estimator (link).
> 2. The clone function calls the "get_params" function of the DES estimator
>    (link, line 60). We don't re-implement this function, so it gets all the
>    parameters, including the pool of classifiers (at this point, they are
>    still "fitted").
> 3. The clone function then clones each parameter with safe=False (line 62).
>    When cloning the pool of classifiers, the result is a pool that is not
>    "fitted" anymore.
>
> The problem is that, to my knowledge, there is no way for my classifier to
> tell "clone" that a parameter should always be deep-copied. I see that
> other ensemble methods in sklearn always fit the base classifiers within the 
> "fit" method of the ensemble, so this problem does not happen there. I would 
> like to know if there is a solution for this problem while having the base 
> classifiers fitted elsewhere.
>
> Here is a short snippet that reproduces the issue:
>
> ---------------------------
>
> from sklearn.model_selection import GridSearchCV, train_test_split
> from sklearn.base import BaseEstimator, ClassifierMixin
> from sklearn.ensemble import BaggingClassifier
> from sklearn.datasets import load_iris
>
>
> class MyClassifier(BaseEstimator, ClassifierMixin):
>     def __init__(self, base_classifiers, k):
>         # Base classifiers that are already trained
>         self.base_classifiers = base_classifiers
>         # Simulate a parameter that we want to do a grid search on
>         self.k = k
>
>     def fit(self, X_dsel, y_dsel):
>         # Here we would fit any parameters for the Dynamic selection
>         # method, not the base classifiers
>         return self
>
>     def predict(self, X):
>         # In practice the methods would do something with the
>         # predictions of each classifier
>         return self.base_classifiers.predict(X)
>
>
> X, y = load_iris(return_X_y=True)
> X_train, X_dsel, y_train, y_dsel = train_test_split(X, y, test_size=0.5)
>
> base_classifiers = BaggingClassifier()
> base_classifiers.fit(X_train, y_train)
>
> clf = MyClassifier(base_classifiers, k=1)
>
> params = {'k': [1, 3, 5, 7]}
> grid = GridSearchCV(clf, params)
>
> # Raises an error that the bagging classifiers are not fitted
> grid.fit(X_dsel, y_dsel)
>
> ---------------------------
>
> Btw, here is the branch that we are using to make the library compatible with 
> sklearn: https://github.com/Menelau/DESlib/tree/sklearn-estimators. The 
> failing test related to this issue is in 
> https://github.com/Menelau/DESlib/blob/sklearn-estimators/deslib/tests/test_des_integration.py#L36
>
> Thanks in advance for any help on this case,
>
> Luiz Gustavo Hafemann
>
> _______________________________________________
> scikit-learn mailing list
> scikit-learn@python.org
> https://mail.python.org/mailman/listinfo/scikit-learn



-- 
Guillaume Lemaitre
INRIA Saclay - Parietal team
Center for Data Science Paris-Saclay
https://glemaitre.github.io/