Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/21 06:04:21 UTC

[GitHub] [tvm] Sunny-Island commented on a diff in pull request #12144: [Auto Scheduler] Upgrade autoscheduler xgboost callback

Sunny-Island commented on code in PR #12144:
URL: https://github.com/apache/tvm/pull/12144#discussion_r926284242


##########
python/tvm/auto_scheduler/cost_model/xgb_model.py:
##########
@@ -539,125 +539,128 @@ def feval(preds, labels):
     return feval
 
 
-def custom_callback(
-    stopping_rounds,
-    metric,
-    fevals,
-    evals=(),
-    log_file=None,
-    maximize=False,
-    verbose_eval=True,
-    skip_every=2,
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
-    from xgboost.callback import _fmt_metric
-
-    try:
-        from xgboost.training import aggcv
-    except ImportError:
-        from xgboost.callback import _aggcv as aggcv
-
-    state = {}
-    metric_shortname = metric.split("-")[1]
-
-    def init(env):
-        """internal function"""
-        bst = env.model
-
-        state["maximize_score"] = maximize
-        state["best_iteration"] = 0
-        if maximize:
-            state["best_score"] = float("-inf")
-        else:
-            state["best_score"] = float("inf")
+class CustomCallback(callback.TrainingCallback):
+    """
+    Callback function for xgboost.
+    Support custom evaluation function and early-stopping.
+    """
 
-        if bst is not None:
-            if bst.attr("best_score") is not None:
-                state["best_score"] = float(bst.attr("best_score"))
-                state["best_iteration"] = int(bst.attr("best_iteration"))
-                state["best_msg"] = bst.attr("best_msg")
+    def __init__(
+        self,
+        stopping_rounds,
+        metric,
+        fevals,
+        evals=(),
+        log_file=None,
+        maximize=False,
+        verbose_eval=True,
+        skip_every=2,
+    ):
+        """Init function"""
+        self.stopping_rounds = stopping_rounds
+        self.metric = metric
+        self.metric_shortname = metric.split("-")[1]
+        self.fevals = fevals
+        self.evals = evals
+        self.log_file = log_file
+        self.maximize = maximize
+        self.verbose_eval = verbose_eval
+        self.skip_every = skip_every
+        self.state = {}
+
+    def after_iteration(self, model, epoch, _evals_log):
+        """Run after each iteration.  Return True when training should stop."""
+        ##### init state #####
+        if not self.state:
+            self.state["maximize_score"] = self.maximize
+            self.state["best_iteration"] = 0
+            if self.maximize:
+                self.state["best_score"] = float("-inf")
             else:
-                bst.set_attr(best_iteration=str(state["best_iteration"]))
-                bst.set_attr(best_score=str(state["best_score"]))
-        else:
-            assert env.cvfolds is not None
-
-    def callback(env):
-        """internal function"""
-        if not state:
-            init(env)
-
-        bst = env.model
-        i = env.iteration
-        cvfolds = env.cvfolds
+                self.state["best_score"] = float("inf")
 
+            assert model is not None
+            if model.attr("best_score") is not None:
+                self.state["best_score"] = float(model.attr("best_score"))
+                self.state["best_iteration"] = int(model.attr("best_iteration"))
+                self.state["best_msg"] = model.attr("best_msg")
+            else:
+                model.set_attr(best_iteration=str(self.state["best_iteration"]))
+                model.set_attr(best_score=str(self.state["best_score"]))
         res_dict = {}
 
-        if i % skip_every == 1:
-            return
+        if epoch % self.skip_every == 1:
+            return False
 
         ##### evaluation #####
-        if cvfolds is not None:
-            for feval in fevals:
-                tmp = aggcv([f.eval(i, feval) for f in cvfolds])
-                for k, mean, std in tmp:
-                    res_dict[k] = [mean, std]
-        else:
-            for feval in fevals:
-                bst_eval = bst.eval_set(evals, i, feval)
-                res = [x.split(":") for x in bst_eval.split()]
-                for kv in res[1:]:
-                    res_dict[kv[0]] = [float(kv[1])]
+        for feval in self.fevals:
+            bst_eval = model.eval_set(self.evals, epoch, feval)
+            res = [x.split(":") for x in bst_eval.split()]
+            for kv in res[1:]:
+                res_dict[kv[0]] = [float(kv[1])]
 
         eval_res = []
         keys = list(res_dict.keys())
-        keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
+        keys.sort(key=lambda x: x if self.metric_shortname not in x else "a" + x)
         for key in keys:
             v = res_dict[key]
             eval_res.append([key] + v)
 
         ##### print eval result #####
-        if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
-            infos = ["XGB iter: %3d" % i]
+        if (
+            not isinstance(self.verbose_eval, bool)
+            and self.verbose_eval
+            and epoch % self.verbose_eval == 0
+        ):
+            infos = ["XGB iter: %3d" % epoch]
             for item in eval_res:
                 if "null" in item[0]:
                     continue
                 infos.append("%s: %.6f" % (item[0], item[1]))
 
             logger.debug("\t".join(infos))
-            if log_file:
-                with open(log_file, "a") as fout:
+            if self.log_file:
+                with open(self.log_file, "a") as fout:
                     fout.write("\t".join(infos) + "\n")
 
         ##### choose score and do early stopping #####
         score = None
         for item in eval_res:
-            if item[0] == metric:
+            if item[0] == self.metric:
                 score = item[1]
                 break
         assert score is not None
 
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
-        maximize_score = state["maximize_score"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
+        maximize_score = self.state["maximize_score"]
+
         if (maximize_score and score > best_score) or (not maximize_score and score < best_score):
-            msg = "[%d] %s" % (env.iteration, "\t".join([_fmt_metric(x) for x in eval_res]))
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = "[%d] %s" % (epoch, "\t".join([self._fmt_metric(x) for x in eval_res]))
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the property to attributes, so they will occur in checkpoint.
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
+        elif epoch - best_iteration >= self.stopping_rounds:
+            best_msg = self.state["best_msg"]
+            if self.verbose_eval:
                 logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
-
-    return callback
+            return True
+
+        return False
+
+    def _fmt_metric(self, value, show_stdv=True):

Review Comment:
   Thank you for your reply! I will resolve these two suggestions before tomorrow.


