zxybazh commented on code in PR #12141:
URL: https://github.com/apache/tvm/pull/12141#discussion_r931349531


##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -35,7 +35,15 @@
 from ..utils import cpu_count, derived_object, shash2hex
 from .metric import max_curve
 
+
 if TYPE_CHECKING:
+    try:

Review Comment:
   I see. If we put it under `TYPE_CHECKING`, it would not be imported in CI 
or local runs and would only work for type hints, just like the xgb library 
imported here. If we need to use it as a class to inherit from, we would need 
to import it again later, like this:
   ```
   import xgboost as xgb  # type: ignore # pylint: disable=import-outside-toplevel
   ```
   In this case we need to import the `TrainingCallback` class to inherit 
from it, and we don't actually need it for type hints, so my suggestion is 
that we don't put it under `TYPE_CHECKING`. Instead, we can do a lazy import 
using a decorator, following this thread: 
https://stackoverflow.com/questions/14879206/lazy-load-configure-a-class-to-inherit-from
   
   Let me know if this works for you : )
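   For reference, the lazy-import-base pattern from that thread could look 
roughly like the sketch below. `lazy_subclass` is a hypothetical helper name 
(not a TVM or xgboost API), and `collections.OrderedDict` just stands in for 
the heavyweight base class such as `TrainingCallback`:
   ```python
   import importlib


   def lazy_subclass(module_name: str, base_name: str):
       """Class decorator: import the real base class only when the
       decorated class is first instantiated (hypothetical helper)."""

       def decorator(cls):
           cache = {}

           def factory(*args, **kwargs):
               if "cls" not in cache:
                   base = getattr(importlib.import_module(module_name), base_name)
                   # Rebuild the class with the lazily imported base;
                   # __dict__/__weakref__ slots can't be passed to type().
                   namespace = {
                       k: v
                       for k, v in cls.__dict__.items()
                       if k not in ("__dict__", "__weakref__")
                   }
                   cache["cls"] = type(cls.__name__, (base,), namespace)
               return cache["cls"](*args, **kwargs)

           return factory

       return decorator


   @lazy_subclass("collections", "OrderedDict")  # stand-in for the xgboost base
   class MyCallback:
       def tag(self):
           return "lazy"
   ```
   One caveat of this sketch: the decorator replaces the class with a factory 
function, so `isinstance(obj, MyCallback)` no longer works, though checking 
against the lazily imported base class still does.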



##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -763,3 +768,162 @@ def callback(env: "xgb.core.CallbackEnv"):
             raise EarlyStopException(best_iteration)
 
     return callback
+
+
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        raise NotImplementedError
+
+
+class XGBoostCustomCallback(XGBoostCallback):
+    """Custom callback class for xgboost to support multiple custom evaluation functions"""
+
+    def __init__(
+        self,
+        early_stopping_rounds: int,
+        verbose_eval: int,
+        fevals: List[Callable],
+        evals: List[Tuple["xgb.DMatrix", str]],
+        focused_metric: str = "tr-p-rmse",
+        cvfolds: List["xgb.training.CVPack"] = None,
+    ):
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.fevals = fevals
+        self.evals = evals
+        self.state: Dict[str, Any] = {}
+        self.focused_metric = focused_metric
+        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+        self.cvfolds = cvfolds
+        if cvfolds is not None:
+            self.aggregated_cv = None
+
+    def init(self, model: "xgb.Booster"):
+        """Internal function for initialization"""
+        booster: "xgb.Booster" = model
+        self.state["best_iteration"] = 0
+        self.state["best_score"] = float("inf")
+        if booster is None:
+            assert self.cvfolds is not None
+            return
+        if booster.attr("best_score") is not None:
+            self.state["best_score"] = float(booster.attr("best_score"))
+            self.state["best_iteration"] = int(booster.attr("best_iteration"))
+            self.state["best_msg"] = booster.attr("best_msg")
+        else:
+            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+            booster.set_attr(best_score=str(self.state["best_score"]))
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
+        """Internal function for after_iteration"""
+        # pylint:disable = import-outside-toplevel

Review Comment:
   I think you can safely re-enable the pylint check right after `aggcv` is imported, right?
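   For what it's worth, the narrow-scope suppression could look like this 
generic sketch, with `hashlib` standing in for the xgboost import (the exact 
`aggcv` import path depends on the xgboost version, so it is not reproduced 
here):
   ```python
   def compute_digest(data: bytes) -> str:
       """Hash some bytes; the deferred import keeps module import cheap."""
       # Suppress the warning only for the deferred import line,
       # then re-enable the check immediately afterwards.
       import hashlib  # pylint: disable=import-outside-toplevel

       # pylint: enable=import-outside-toplevel
       return hashlib.sha256(data).hexdigest()
   ```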


