zxybazh commented on code in PR #12141: URL: https://github.com/apache/tvm/pull/12141#discussion_r931349531
########## python/tvm/meta_schedule/cost_model/xgb_model.py: ########## @@ -35,7 +35,15 @@ from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve + if TYPE_CHECKING: + try: Review Comment: I see. If we put it under `TYPE_CHECKING`, it would not be imported at runtime (in CI or local runs); it would only be available for type hints, just like the xgb library imported here. If we need to use it as a base class to inherit from, we would need to import it again later, like this: ``` import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel ``` In this case we need to import the `TrainingCallback` class at runtime, and we don't actually use it for type hints, so my suggestion is to not put it under `TYPE_CHECKING`. Instead, we can import it lazily using a decorator, following this thread: https://stackoverflow.com/questions/14879206/lazy-load-configure-a-class-to-inherit-from Let me know if this works for you : ) ########## python/tvm/meta_schedule/cost_model/xgb_model.py: ########## @@ -763,3 +768,162 @@ def callback(env: "xgb.core.CallbackEnv"): raise EarlyStopException(best_iteration) return callback + + +class XGBoostCallback(TrainingCallback): + """Base class for XGBoost callbacks.""" + + def __call__(self, env: "xgb.core.CallbackEnv"): + # Compatibility with xgboost < 1.3 + return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) + + def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + raise NotImplementedError + + +class XGBoostCustomCallback(XGBoostCallback): + """Custom callback class for xgboost to support multiple custom evaluation functions""" + + def __init__( + self, + early_stopping_rounds: int, + verbose_eval: int, + fevals: List[Callable], + evals: List[Tuple["xgb.DMatrix", str]], + focused_metric: str = "tr-p-rmse", + cvfolds: List["xgb.training.CVPack"] = None, + ): + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.fevals = 
fevals + self.evals = evals + self.state: Dict[str, Any] = {} + self.focused_metric = focused_metric + self.sort_key = make_metric_sorter(focused_metric=focused_metric) + self.cvfolds = cvfolds + if cvfolds is not None: + self.aggregated_cv = None + + def init(self, model: "xgb.Booster"): + """Internal function for intialization""" + booster: "xgb.Booster" = model + self.state["best_iteration"] = 0 + self.state["best_score"] = float("inf") + if booster is None: + assert self.cvfolds is not None + return + if booster.attr("best_score") is not None: + self.state["best_score"] = float(booster.attr("best_score")) + self.state["best_iteration"] = int(booster.attr("best_iteration")) + self.state["best_msg"] = booster.attr("best_msg") + else: + booster.set_attr(best_iteration=str(self.state["best_iteration"])) + booster.set_attr(best_score=str(self.state["best_score"])) + + def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): + """Internal function for after_iteration""" + # pylint:disable = import-outside-toplevel Review Comment: I think you can safely re-enable this pylint check right after `aggcv` is imported, right? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
