This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7249904622 [AutoScheduler][AutoTVM] Enable xgboost >= 1.7.x new changes (#14036)
7249904622 is described below

commit 7249904622e45d37d3d74a11d88929191a3bc622
Author: Balint Cristian <[email protected]>
AuthorDate: Sun Feb 19 02:19:51 2023 +0200

    [AutoScheduler][AutoTVM] Enable xgboost >= 1.7.x new changes (#14036)
    
    Enable xgboost >= 1.7.x new changes
---
 docs/install/from_source.rst                      |   4 +-
 python/gen_requirements.py                        |   2 +-
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 202 ++++++++++++----------
 python/tvm/autotvm/tuner/xgboost_cost_model.py    | 201 ++++++++++++---------
 4 files changed, 238 insertions(+), 171 deletions(-)

diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst
index 37ca72d80f..c2e4c7acd9 100644
--- a/docs/install/from_source.rst
+++ b/docs/install/from_source.rst
@@ -347,7 +347,7 @@ like ``virtualenv``.
 
    .. code:: bash
 
-       pip3 install --user tornado psutil 'xgboost<1.6.0' cloudpickle
+       pip3 install --user tornado psutil 'xgboost>=1.1.0' cloudpickle
 
 Note on M1 macs, you may have trouble installing xgboost / scipy. scipy and xgboost requires some additional dependencies to be installed,
 including openblas and its dependencies. Use the following commands to install scipy and xgboost with the required dependencies and
@@ -363,7 +363,7 @@ configuration. A workaround for this is to do the following commands:
 
         pip install scipy --no-use-pep517
 
-        pip install 'xgboost<1.6.0'
+        pip install 'xgboost>=1.1.0'
 
 Install Contrib Libraries
 -------------------------
diff --git a/python/gen_requirements.py b/python/gen_requirements.py
index 7f5fe57adb..09fb57ee94 100644
--- a/python/gen_requirements.py
+++ b/python/gen_requirements.py
@@ -276,7 +276,7 @@ CONSTRAINTS = [
     ("torch", None),
     ("torchvision", None),
     ("tornado", None),
-    ("xgboost", ">=1.1.0,<1.6.0"),  # From PR #4953 & Issue #12009
+    ("xgboost", ">=1.1.0"),  # From PR #4953 & Issue #12009
 ]
 
 
################################################################################
diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index a4e39b9061..328e25db7b 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -19,6 +19,7 @@
 """Cost model based on xgboost"""
 import multiprocessing
 import logging
+from typing import Dict
 from collections import defaultdict
 
 import numpy as np
@@ -28,6 +29,14 @@ from .cost_model import PythonBasedModel
 from ..feature import get_per_store_features_from_measure_pairs, 
get_per_store_features_from_states
 from ..measure_record import RecordReader
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass
+
+
 xgb = None
 
 logger = logging.getLogger("auto_scheduler")
@@ -198,7 +207,7 @@ class XGBModel(PythonBasedModel):
             num_boost_round=10000,
             obj=pack_sum_square_error,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=50,
                     metric="tr-p-rmse",
                     fevals=[
@@ -539,125 +548,144 @@ def pack_sum_average_peak_score(N):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation 
functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
-            else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, 
env.evaluation_result_list)
 
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: 
Dict):
+        raise NotImplementedError
+
+
+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
+
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
 
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: 
Dict):
+        """Run after each iteration.  Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
+            else:
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = 
int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                
model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + 
x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % 
verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and 
score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x 
in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in 
eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in 
checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True
 
-    return callback
+        return False
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index 6fa04f336f..a80c350903 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -20,6 +20,8 @@
 import logging
 import time
 
+from typing import Dict
+
 import numpy as np
 from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
 
@@ -28,6 +30,14 @@ from ..utils import get_rank
 from .metric import cover_curve, max_curve, recall_curve
 from .model_based_tuner import CostModel, FeatureCache
 
+try:
+    from xgboost.callback import TrainingCallback  # type: ignore
+except ImportError:
+
+    class TrainingCallback:  # type: ignore
+        pass
+
+
 xgb = None
 
 logger = logging.getLogger("autotvm")
@@ -198,7 +208,7 @@ class XGBoostCostModel(CostModel):
             dtrain,
             num_boost_round=8000,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=20,
                     metric="tr-a-recall@%d" % plan_size,
                     evals=[(dtrain, "tr")],
@@ -282,7 +292,7 @@ class XGBoostCostModel(CostModel):
             dtrain,
             num_boost_round=400,
             callbacks=[
-                custom_callback(
+                CustomCallback(
                     stopping_rounds=100,
                     metric="tr-a-recall@%d" % plan_size,
                     evals=[(dtrain, "tr")],
@@ -443,118 +453,147 @@ def _extract_curve_feature_log(arg):
     return x, y
 
 
-def custom_callback(
-    stopping_rounds, metric, fevals, evals=(), log_file=None, maximize=False, 
verbose_eval=True
-):
-    """callback function for xgboost to support multiple custom evaluation 
functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.callback import _fmt_metric
-    from xgboost.core import EarlyStopException
+class XGBoostCallback(TrainingCallback):
+    """Base class for XGBoost callbacks."""
 
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, 
env.evaluation_result_list)
 
-    state = {}
-    metric_shortname = metric.split("-")[1]
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: 
Dict):
+        raise NotImplementedError
 
-    def init(env):
-        """internal function"""
-        bst = env.model
 
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class CustomCallback(XGBoostCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: 
Dict):
+        """Run after each iteration.  Return True when training should stop."""
+        # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+            def _fmt_metric(value, show_stdv=True):
+                """format metric string"""
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
             else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
-
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
-
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = 
int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                
model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
+        if epoch % self.skip_every == 1:
+            return False
+
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + 
x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        infos = ["XGB iter: %3d" % i]
-        for item in eval_res:
-            if "null" in item[0]:
-                continue
-            infos.append("%s: %.6f" % (item[0], item[1]))
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
+            for item in eval_res:
+                if "null" in item[0]:
+                    continue
+                infos.append("%s: %.6f" % (item[0], item[1]))
 
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % 
verbose_eval == 0:
             logger.debug("\t".join(infos))
-        if log_file:
-            with open(log_file, "a") as fout:
-                fout.write("\t".join(infos) + "\n")
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
+                    fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and 
score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x 
in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([_fmt_metric(x) for x in 
eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in 
checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+            return True
 
-    return callback
+        return False
 
 
 # feval wrapper for xgboost

Reply via email to