Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/06 20:49:16 UTC

[GitHub] [tvm] junrushao1994 opened a new pull request #9859: [MetaSchedule] XGB-based Cost Model

junrushao1994 opened a new pull request #9859:
URL: https://github.com/apache/tvm/pull/9859


   This PR is part of the stage M3c of the meta schedule project (#8473).
   
   The architecture was re-designed by Junru and Xiyou. In this PR we introduce an XGBoost-based cost model built on top of the meta schedule cost model interface. Unit tests are included.
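
   For reference, here is a minimal construction sketch (not a snippet from this PR or its unit tests; `RandomFeatureExtractor` and the parameter values are only illustrative placeholders):

       from tvm.meta_schedule.cost_model.xgb_model import XGBConfig, XGBModel
       from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor

       # Build the cost model from a feature extractor plus (optional) XGBoost settings.
       model = XGBModel(
           extractor=RandomFeatureExtractor(),        # placeholder extractor
           config=XGBConfig(max_depth=10, eta=0.2),   # these happen to be the defaults
           num_warmup_samples=100,
       )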
   
   Thanks to all co-authors for contributing!
   
   Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
   Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
   Co-authored-by: Ruihang Lai <la...@qq.com>
   Co-authored-by: Hongyi Jin <32...@qq.com>
   Co-authored-by: Wuwei Lin <wu...@apache.org>
   Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#discussion_r779913135



##########
File path: python/tvm/meta_schedule/cost_model/xgb_model.py
##########
@@ -0,0 +1,680 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+XGBoost-based cost model
+"""
+from itertools import chain as itertools_chain
+import logging
+import os
+import tempfile
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Tuple
+
+import numpy as np
+
+from ...contrib.tar import tar, untar
+from ..cost_model import PyCostModel
+from ..feature_extractor import FeatureExtractor
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..utils import cpu_count
+from .metric import max_curve
+
+if TYPE_CHECKING:
+    from ..tune_context import TuneContext
+    import xgboost as xgb
+
+
+logger = logging.getLogger(__name__)
+
+
+def make_metric_sorter(focused_metric):
+    """ Make sure the focused metric is the first one. """
+
+    def metric_name_for_sort(name):
+        if focused_metric == name:
+            return "!" + name
+        return name
+
+    def sort_key(key):
+        key, _ = key
+        return metric_name_for_sort(key)
+
+    return sort_key
+
+
+class PackSum:
+    """The pack-sum format
+
+    Parameters
+    ----------
+    dmatrix : xgb.DMatrix
+        A float64 array of shape [n, m],
+        where `n` is the packed number of blocks,
+        and `m` is the length of the feature vector of each block
+    ids : np.ndarray
+        An int64 array of shape [n] containing nonnegative integers,
+        indicating the index of the sample that each block belongs to
+    """
+
+    dmatrix: "xgb.DMatrix"  # type: ignore # pylint: disable=invalid-name
+    ids: np.ndarray
+
+    def __init__(
+        self,
+        xs: List[np.ndarray],
+        ys: Optional[np.ndarray],
+    ):
+        """Create PackSum format given a batch of samples
+
+        Parameters
+        ----------
+        xs : List[np.ndarray]
+            A batch of input samples
+        ys : Optional[np.ndarray]
+            A batch of labels. None means no labels available.
+        """
+        import xgboost as xgb  # type: ignore # pylint: disable=import-outside-toplevel
+
+        repeats = [x.shape[0] for x in xs]
+        xs = np.concatenate(xs, axis=0)
+        self.ids = np.concatenate([[i] * repeat for i, repeat in enumerate(repeats)], axis=0)
+        if ys is None:
+            self.dmatrix = xgb.DMatrix(data=xs, label=None)
+        else:
+            ys = np.concatenate([[y] * repeat for y, repeat in zip(ys, repeats)], axis=0)
+            self.dmatrix = xgb.DMatrix(data=xs, label=ys)
+            self.dmatrix.set_weight(ys)
+
+    def predict_with_score(self, pred: np.ndarray) -> np.ndarray:
+        """Predict the labels given the block level prediction scores.
+
+        Parameters
+        ----------
+        pred : np.ndarray
+            The block level predictions
+
+        Returns
+        -------
+        result : np.ndarray
+            The predictions for each candidate.
+        """
+        return np.bincount(self.ids, weights=pred)
+
+    def obj_square_error(self, ys_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        """Implement square error loss on pack-sum format as
+        a custom objective function for xgboost.
+
+        Parameters
+        ----------
+        ys_pred: np.ndarray
+            The predictions
+
+        Returns
+        -------
+        gradient: np.ndarray
+            The gradient according to the xgboost format
+        hessian: np.ndarray
+            The hessian according to the xgboost format
+        """
+        # Making prediction
+        ys_pred = self.predict_with_score(ys_pred)
+        # Propagate prediction to each block
+        ys_pred = ys_pred[self.ids]
+        # The gradient and hessian
+        ys = self.dmatrix.get_label()  # type: ignore # pylint: disable=invalid-name
+        gradient = ys_pred - ys
+        hessian = np.ones_like(gradient)
+        return gradient * ys, hessian * ys
+
+    def rmse(self, ys_pred: np.ndarray) -> Tuple[str, float]:
+        """Evaluate RMSE (rooted mean square error) in the pack-sum format
+
+        Parameters
+        ----------
+        ys_pred: np.ndarray
+            The raw predictions
+
+        Returns
+        -------
+        name: str
+            The name of the metric
+        score: float
+            The score of the metric
+        """
+        # Making prediction
+        ys_pred = self.predict_with_score(ys_pred)
+        # Propagate prediction to each block
+        ys_pred = ys_pred[self.ids]
+        # The RMSE
+        ys = self.dmatrix.get_label()  # type: ignore # pylint: disable=invalid-name
+        square_error = np.square(ys_pred - ys)
+        rmse = np.sqrt(square_error.mean())
+        return "p-rmse", rmse
+
+    def average_peak_score(
+        self,
+        ys_pred: np.ndarray,
+        n: int,
+    ) -> Tuple[str, float]:
+        """Evaluate average-peak-score@N in the pack-sum format
+
+        Parameters
+        ----------
+        ys_pred: np.ndarray
+            The raw prediction
+        n : int
+            The N in average-peak-score@N
+
+        Returns
+        -------
+        name: str
+            The name of the metric
+        score: float
+            The score of the metric
+        """
+        ys = self.dmatrix.get_label()  # type: ignore # pylint: disable=invalid-name
+        ys = self.predict_with_score(ys)  # type: ignore # pylint: disable=invalid-name
+        ys = ys / np.unique(self.ids, return_counts=True)[1]  # type: ignore # pylint: disable=invalid-name
+        ys_pred = self.predict_with_score(ys_pred)
+        trials = np.argsort(ys_pred)[::-1][:n]
+        trial_scores = ys[trials]
+        curve = max_curve(trial_scores) / np.max(ys)
+        score = np.mean(curve)
+        return f"a-peak@{n}", score
+
+
+class XGBConfig(NamedTuple):
+    """XGBoost model configuration
+
+    Parameters
+    ----------
+    max_depth : int
+        The maximum depth.
+    gamma : float
+        The gamma, i.e. the minimum loss reduction required to make a further split.
+    min_child_weight : float
+        The minimum child weight.
+    eta : float
+        The eta, learning rate.
+    seed : int
+        The random seed.
+    nthread : Optional[int]
+        The number of threads to use.
+        Defaults to None, which means using the number of physical cores.
+    """
+
+    def to_dict(self):
+        xgb_params = {
+            "max_depth": self.max_depth,
+            "gamma": self.gamma,
+            "min_child_weight": self.min_child_weight,
+            "eta": self.eta,
+            "seed": self.seed,
+            "nthread": self.nthread,
+        }
+        return xgb_params
+
+    max_depth: int = 10
+    gamma: float = 0.001
+    min_child_weight: float = 0
+    eta: float = 0.2
+    seed: int = 43
+    nthread: Optional[int] = None
+
+
+class XGBModel(PyCostModel):
+    """XGBoost model
+
+    Parameters
+    ----------
+    extractor : FeatureExtractor
+        The feature extractor for the model.
+    config : XGBConfig
+        The XGBoost model config.
+    num_warmup_samples : int
+        The number of samples that are used for warmup, i.e., the first few samples are predicted
+        with random results.
+    early_stopping_rounds : int
+        The number of rounds for early stopping.
+    verbose_eval : int
+        The verbose level when doing evaluation.
+    average_peak_n : int
+        The N used to calculate the average peak score.
+    """
+
+    # feature extractor
+    extractor: FeatureExtractor
+    # xgboost model config
+    config: XGBConfig
+    # behavior of randomness
+    num_warmup_samples: int
+    # evaluation
+    early_stopping_rounds: int
+    verbose_eval: int
+    average_peak_n: int
+    # states
+    cached_features: List[np.ndarray]
+    cached_mean_costs: np.ndarray
+    cached_normalizer: Optional[float]
+    booster: Optional["xgb.Booster"]
+
+    def __init__(
+        self,
+        *,
+        # feature extractor
+        extractor: FeatureExtractor,
+        # xgboost model config
+        config: XGBConfig = XGBConfig(),
+        # behavior of randomness
+        num_warmup_samples: int = 100,
+        # evaluation
+        early_stopping_rounds: int = 50,
+        verbose_eval: int = 25,
+        average_peak_n: int = 32,
+    ):
+        super().__init__()
+        # feature extractor
+        self.extractor = extractor
+        # model-related
+        if config.nthread is None:
+            # use physical core number
+            config = config._replace(nthread=cpu_count(logical=False))
+        self.config = config
+        # behavior of randomness
+        self.num_warmup_samples = num_warmup_samples
+        # evaluation
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.average_peak_n = average_peak_n
+        # states
+        self.cached_features = []
+        self.cached_mean_costs = np.empty((0,), dtype="float64")
+        self.cached_normalizer = None
+        self.booster = None
+
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        Since XGBoost model trains from scratch, each time we can only load the model without the
+        previous cached features / results so any call of update won't use previous training data.

Review comment:
       Sorry, this comment is wrong (CC @zxybazh). Actually we did load/store the cached feature vectors to make sure all the existing data are used for training.
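
       For what it's worth, here is a simplified sketch of the save path (not the exact code in this PR; the file names and the `model` argument are hypothetical) that keeps the booster and the cached training data together:

           import os
           import tempfile

           import numpy as np
           from tvm.contrib.tar import tar

           def save_sketch(model, path: str) -> None:
               # Persist the booster together with the cached features / mean costs,
               # so that a later `update` can retrain on all previously observed data.
               with tempfile.TemporaryDirectory() as tmp_dir:
                   model_path = os.path.join(tmp_dir, "model.bin")
                   data_path = os.path.join(tmp_dir, "data.npz")
                   model.booster.save_model(model_path)
                   np.savez(data_path, *model.cached_features, costs=model.cached_mean_costs)
                   tar(path, [model_path, data_path])  # bundle both files into one archive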







[GitHub] [tvm] junrushao1994 commented on pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#issuecomment-1006925131


   CC: @comaniac @yzhliu @merrymercy 





[GitHub] [tvm] zxybazh commented on a change in pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
zxybazh commented on a change in pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#discussion_r779914252



##########
File path: python/tvm/meta_schedule/cost_model/xgb_model.py
##########
@@ -0,0 +1,680 @@
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        Since XGBoost model trains from scratch, each time we can only load the model without the
+        previous cached features / results so any call of update won't use previous training data.

Review comment:
       Thanks for pointing this out. This comment was written before we started storing the cached features, so please feel free to remove it.







[GitHub] [tvm] junrushao1994 commented on pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#issuecomment-1006999426


   @comaniac Thanks for the extremely valuable feedback!
   
   > when training data gets bigger and bigger, the time to train the XGBoost cost model becomes tedious even the accuracy isn't further improved
   
   That's exactly what I'm observing too! In this particular case, the XGBoost hyperparameters might no longer be suitable, which limits the model capacity, and we might have to experiment to find the best hyperparameters.
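
   For example (illustrative values only, not a recommendation), the capacity-related knobs can be adjusted through `XGBConfig` when constructing the model:

       from tvm.meta_schedule.cost_model.xgb_model import XGBConfig

       # Hypothetical higher-capacity settings for a larger training set;
       # the values are made up for illustration, not tuned defaults.
       big_data_config = XGBConfig(
           max_depth=12,          # deeper trees -> more capacity
           eta=0.1,               # smaller learning rate
           min_child_weight=0,
           gamma=0.001,
           seed=43,
       )
       print(big_data_config.to_dict())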
   
   > What Ansor has done is simply reduce the re-training frequency (e.g., re-train per 2 rounds) when training data size is larger than a threshold.
   
   This is how Ansor deals with it right now... We might consider better heuristics in the future, including switching models, tweaking model capacity with AutoML techniques, etc.
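
   As a rough illustration of that heuristic (a sketch only; the threshold and the helper name are made up, and this is neither Ansor's nor meta schedule's actual code):

       RETRAIN_EVERY_N_ROUNDS = 2      # re-train every 2 rounds once the data is large
       DATA_SIZE_THRESHOLD = 5000      # switch-over point, purely illustrative

       def should_retrain(round_idx: int, num_samples: int) -> bool:
           # Re-train every round while the data set is small; once it grows past
           # the threshold, only re-train every N rounds to bound training time.
           if num_samples <= DATA_SIZE_THRESHOLD:
               return True
           return round_idx % RETRAIN_EVERY_N_ROUNDS == 0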
   
   > we can also refer to the accuracy between the predicted cost and new measured latencies to determine whether to re-train the model in the next round
   
   Using our current interface, this is pretty simple to do. We have a `validate` method that reports the RMSE of the cost model's predictions, and I used this method quite frequently when debugging the model.
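
   Something along these lines (a sketch only, not the actual `validate` signature; `rmse_threshold` is a made-up knob):

       import numpy as np

       def should_retrain_by_accuracy(predicted_costs, measured_latencies, rmse_threshold=0.1) -> bool:
           # Compare the model's predictions for the newly measured candidates against
           # the actual latencies; only re-train when the error grows too large.
           predicted = np.asarray(predicted_costs, dtype="float64")
           measured = np.asarray(measured_latencies, dtype="float64")
           rmse = float(np.sqrt(np.mean(np.square(predicted - measured))))
           return rmse > rmse_threshold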
   
   
   Anyway, I think we are pretty aligned on the methodology and the path to improvement. Let's work together to improve it in the future.





[GitHub] [tvm] junrushao1994 merged pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9859:
URL: https://github.com/apache/tvm/pull/9859


   





[GitHub] [tvm] comaniac commented on a change in pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#discussion_r779898360



##########
File path: python/tvm/meta_schedule/cost_model/xgb_model.py
##########
@@ -0,0 +1,680 @@
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        Since XGBoost model trains from scratch, each time we can only load the model without the
+        previous cached features / results so any call of update won't use previous training data.

Review comment:
       I didn't quite get this note. It seems to me that you load the previous model along with the cached features and costs. Aren't the loaded cached features and costs used in future training?










[GitHub] [tvm] junrushao1994 commented on a change in pull request #9859: [MetaSchedule] XGB-based Cost Model

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9859:
URL: https://github.com/apache/tvm/pull/9859#discussion_r779914185



##########
File path: python/tvm/meta_schedule/cost_model/xgb_model.py
##########
@@ -0,0 +1,680 @@
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        Since XGBoost model trains from scratch, each time we can only load the model without the
+        previous cached features / results so any call of update won't use previous training data.

Review comment:
       Updated with accurate documentation



