Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/11 11:32:22 UTC
[tvm] branch main updated: [MetaSchedule] Added a cost model (#11961)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9ee25eb9f4 [MetaSchedule] Added a cost model (#11961)
9ee25eb9f4 is described below
commit 9ee25eb9f45f27d52fdc308dbd8970e5f095fef6
Author: Kathryn (Jinqi) Chen <65...@users.noreply.github.com>
AuthorDate: Mon Jul 11 04:32:13 2022 -0700
[MetaSchedule] Added a cost model (#11961)
In this PR, I added a cost model based on a SegmentSum MLP, which can be pre-trained offline or integrated directly into TVM's MetaSchedule tuning.
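As a rough sketch (not part of this commit), the model could be constructed, pre-trained on an offline dataset, and cached along these lines; the tar paths are hypothetical, and the raw-data layout is the one State.load expects (raw_data/*/*_workload.json with matching *_candidates.json):

    from tvm.meta_schedule.cost_model.mlp_model import (
        MLPModel,
        SegmentSumMLPConfig,
        SegmentSumMLPTrainer,
        State,
        TrainerConfig,
    )

    # Assemble the trainer state with the default per-store feature extractor.
    state = State(model_config=SegmentSumMLPConfig(input_dim=172, hidden_dim=256))
    trainer = SegmentSumMLPTrainer(train_config=TrainerConfig(num_epoch_full=50), state=state)
    model = MLPModel(trainer=trainer)

    model.load("/tmp/mlp_dataset.tar")     # raw or cached tuning data (hypothetical path)
    trainer.train_full()                   # offline pre-training on all loaded feature groups
    model.save("/tmp/mlp_pretrained.tar")  # cache the trained weights and extracted features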
---
python/tvm/meta_schedule/cost_model/cost_model.py | 2 +-
python/tvm/meta_schedule/cost_model/mlp_model.py | 1010 +++++++++++++++++++++
src/meta_schedule/database/json_database.cc | 2 +-
3 files changed, 1012 insertions(+), 2 deletions(-)
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
index 2fdb9b9349..d3b660d837 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -190,7 +190,7 @@ class PyCostModel:
raise NotImplementedError
def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
- """Update the cost model given running results.
+ """Predict given the measure candidates.
Parameters
----------
diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py
new file mode 100644
index 0000000000..04ccca0563
--- /dev/null
+++ b/python/tvm/meta_schedule/cost_model/mlp_model.py
@@ -0,0 +1,1010 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# type: ignore[import]
+"""
+Segment Sum MLP cost model
+"""
+import glob
+import logging
+import math
+import os
+import random
+import tempfile
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import Dict, List, NamedTuple, Tuple
+
+import numpy as np # type: ignore
+import torch # type: ignore
+import tvm
+
+from ...contrib.tar import tar, untar
+from ...runtime import NDArray
+from ...target import Target
+from ..cost_model import PyCostModel
+from ..database import JSONDatabase
+from ..feature_extractor import FeatureExtractor, PerStoreFeature
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object, shash2hex
+
+logging.basicConfig()
+logger = logging.getLogger("mlp_model") # pylint: disable=invalid-name
+logger.setLevel(logging.INFO)
+
+# pylint: disable=no-member,import-outside-toplevel
+
+
+class SegmentSumMLPConfig(NamedTuple):
+ """SegmentSum MLP model configuration
+
+ Parameters
+ ----------
+ input_dim : int
+ The input dim for the model.
+ hidden_dim : int
+ The hidden dim for the model.
+ output_dim : int
+ The output dim for the model.
+ use_norm : bool
+ Whether to normalize the segment sum or not.
+ use_sigmoid : bool
+ Whether to use sigmoid on the final output or not.
+ """
+
+ input_dim: int = 172
+ hidden_dim: int = 256
+ output_dim: int = 1
+ use_norm: bool = False
+ use_sigmoid: bool = False
+
+ def to_dict(self): # pylint: disable=missing-function-docstring
+ return {
+ "input_dim": self.input_dim,
+ "hidden_dim": self.hidden_dim,
+ "output_dim": self.output_dim,
+ "use_norm": self.use_norm,
+ "use_sigmoid": self.use_sigmoid,
+ }
+
+
+class TrainerConfig(NamedTuple):
+ """Trainer configuration
+
+ Parameters
+ ----------
+ batch_size : int
+ The batch size.
+ learning_rate : float
+ The learning rate.
+ weight_decay : float
+ The weight decay.
+ num_epoch_full : int
+ The number of epochs used in full training.
+ num_epoch_incremental : int
+ The number of epochs used in incremental training.
+ grad_clip_norm: float
+ The norm of gradient clipping.
+ train_verbose: int
+ The logging frequency during training, in batches.
+ test_interval: int
+ The testing interval in epochs.
+ test_split: float
+ The fraction of data for testing.
+ frozen: bool
+ Whether to freeze the model, i.e. skip re-training when new data arrives.
+ """
+
+ batch_size: int = 128
+ learning_rate: float = 7e-4
+ weight_decay: float = 1e-6
+ num_epoch_full: int = 50
+ num_epoch_incremental: int = 5
+ grad_clip_norm: float = 0.5
+ train_verbose: int = 1000
+ test_interval: int = 1
+ test_split: float = 0.2
+ frozen: bool = False
+
+ def to_dict(self): # pylint: disable=missing-function-docstring
+ return {
+ "batch_size": self.batch_size,
+ "learning_rate": self.learning_rate,
+ "weight_decay": self.weight_decay,
+ "num_epoch_full": self.num_epoch_full,
+ "num_epoch_incremental": self.num_epoch_incremental,
+ "grad_clip_norm": self.grad_clip_norm,
+ "train_verbose": self.train_verbose,
+ "test_interval": self.test_interval,
+ "test_split": self.test_split,
+ "frozen": self.frozen,
+ }
+
+
+# pylint: disable=too-few-public-methods
+class FeatureGroup:
+ """Feature group
+
+ Parameters
+ ----------
+ group_hash : str
+ The hash of the group
+ features : List[np.ndarray]
+ The features
+ costs : np.ndarray
+ The costs
+ min_cost : float
+ The minimum cost
+ """
+
+ group_hash: str
+ features: List[np.ndarray]
+ costs: np.ndarray
+ min_cost: float
+
+ def __init__(
+ self,
+ group_hash: str,
+ features: List[np.ndarray],
+ costs: np.ndarray,
+ ) -> None:
+ self.group_hash = group_hash
+ self.features = features
+ self.costs = costs
+ self.min_cost = np.min(costs)
+
+ def append( # pylint: disable=missing-function-docstring
+ self,
+ features: List[np.ndarray],
+ costs: np.ndarray,
+ ) -> None:
+ self.features.extend(features)
+ self.costs = np.append(self.costs, costs)
+ self.min_cost = np.min(self.costs)
+
+
+# pylint: disable=too-many-instance-attributes
+class SegmentDataLoader:
+ """Dataloader for Segment Sum MLP model.
+
+ Parameters
+ ----------
+ features : List[np.ndarray]
+ The features
+ results : np.ndarray
+ The measured results; can be None.
+ batch_size : int
+ The batch size
+ shuffle : bool
+ Whether to shuffle the dataset or not
+ """
+
+ def __init__(
+ self,
+ features,
+ results=None,
+ batch_size=128,
+ shuffle=True,
+ ):
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.data_size = len(features)
+
+ # flatten features and store the starting indices
+ self.segment_sizes = torch.tensor([len(feature) for feature in features], dtype=torch.int32)
+ self.feature_offsets = (
+ torch.cumsum(self.segment_sizes, 0, dtype=torch.int32) - self.segment_sizes
+ )
+ features = torch.cat([torch.tensor(feature) for feature in features])
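+ # Per-dimension max-normalization over the whole dataset; all-zero columns get norm 1 to avoid division by zero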
+ norm, _ = features.max(dim=0)
+ norm[norm == 0] = 1
+ self.features = features / norm
+ self.results = torch.tensor(results) if results is not None else None
+ self.iter_order = self.pointer = None
+
+ def __len__(self):
+ return self.data_size
+
+ def __iter__(self):
+ if self.shuffle:
+ self.iter_order = torch.randperm(self.data_size)
+ else:
+ self.iter_order = torch.arange(self.data_size)
+ self.pointer = 0
+ return self
+
+ def __next__(self):
+ if self.pointer >= self.data_size:
+ raise StopIteration
+ batch_indices = self.iter_order[self.pointer : self.pointer + self.batch_size]
+ self.pointer += self.batch_size
+ return self._fetch_indices(batch_indices)
+
+ def _fetch_indices(self, indices):
+ segment_sizes, feature_offsets = self.segment_sizes[indices], self.feature_offsets[indices]
+ feature_indices = torch.empty(segment_sizes.sum(), dtype=torch.int32)
+ idx = 0
+ for offset, seg_size in zip(feature_offsets, segment_sizes):
+ feature_indices[idx : idx + seg_size] = torch.arange(offset, offset + seg_size)
+ idx += seg_size
+ features = self.features[feature_indices.long()]
+ results = None
+ if self.results is not None:
+ results = self.results[indices.long()]
+ return segment_sizes, features, results
+
+
+def lambda_rank_loss( # pylint: disable=too-many-locals
+ preds: "torch.Tensor",
+ labels: "torch.Tensor",
+ k: int = None,
+ eps: float = 1e-10,
+ sigma: float = 1.0,
+) -> "torch.Tensor":
+ """
+ LambdaLoss: Metric-Driven Loss for Learning-to-Rank
+
+ Parameters
+ ----------
+ preds : Tensor
+ The predicted runtime for each candidate.
+ labels : Tensor
+ The measured runtime for each candidate.
+ k : int
+ Compute the loss over only the top k items.
+ Default is None, which uses all items.
+ eps : float
+ The minimum value to the denominator and argument of log if they reach 0.
+ sigma : float
+ The scaling factor to the input of the sigmoid function.
+
+ Returns
+ -------
+ loss : Tensor
+ The lambda rank loss.
+ """
+ device = preds.device
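+ # Treat the whole batch as one ranked list: add a leading "query" dimension of size 1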
+ y_pred, y_true = preds[None, :], labels[None, :]
+ y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
+ y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
+ true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
+ true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
+ padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0)
+ ndcg_at_k_mask = torch.zeros(
+ (y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device
+ )
+ ndcg_at_k_mask[:k, :k] = 1
+ true_sorted_by_preds.clamp_(min=0.0)
+ y_true_sorted.clamp_(min=0.0)
+ pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
+ D = torch.log2(1.0 + pos_idxs.float())[None, :] # pylint: disable=invalid-name
+ maxDCGs = torch.sum( # pylint: disable=invalid-name
+ ((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1
+ ).clamp(min=eps)
+ G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] # pylint: disable=invalid-name
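+ # LambdaRank weights: approximate |delta NDCG| of swapping each pair, from differences in position discounts and gains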
+ weights = torch.abs(
+ torch.pow(D[:, :, None], -1.0) - torch.pow(D[:, None, :], -1.0)
+ ) * torch.abs(G[:, :, None] - G[:, None, :])
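+ # RankNet-style pairwise probabilities from predicted score differences, raised to the lambda weights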
+ scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
+ scores_diffs[torch.isnan(scores_diffs)] = 0.0
+ weighted_probs = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
+ losses = torch.log2(weighted_probs)
+ masked_losses = losses[padded_pairs_mask & ndcg_at_k_mask]
+ loss = -torch.sum(masked_losses)
+ return loss
+
+
+def topk_score(
+ pred_results: "torch.Tensor",
+ gt_results: "torch.Tensor",
+ k: int,
+) -> float:
+ """
+ Evaluate the top-k score
+
+ Parameters
+ ----------
+ pred_results: Tensor
+ The raw prediction
+ gt_results: Tensor
+ The measured results
+ k : int
+ The k in top k score
+
+ Returns
+ -------
+ score : float
+ The top-k score
+ """
+ k = min(k, len(pred_results))
+ topk_indices = torch.topk(pred_results, k, largest=False).indices
+ score = gt_results.min() / gt_results[topk_indices].min()
+ return score.item()
+
+
+class SegmentSumMLP(torch.nn.Module):
+ """Segment Sum MLP model.
+
+ Parameters
+ ----------
+ input_dim : int
+ The input dim for the model.
+ hidden_dim : int
+ The hidden dim for the model.
+ output_dim : int
+ The output dim for the model.
+ use_norm : bool
+ Whether to normalize the segment sum or not.
+ use_sigmoid : bool
+ Whether to use sigmoid on the final output or not.
+ """
+
+ input_dim: int
+ hidden_dim: int
+ output_dim: int
+ use_norm: bool
+ use_sigmoid: bool
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ input_dim: int = 172,
+ hidden_dim: int = 256,
+ output_dim: int = 1,
+ use_norm: bool = False,
+ use_sigmoid: bool = False,
+ ):
+ from torch import nn # type: ignore
+
+ super().__init__()
+ self.encoder = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ )
+ self.norm = nn.BatchNorm1d(hidden_dim) if use_norm else nn.Identity()
+ self.layer0 = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ )
+ self.layer1 = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ )
+ self.decoder = nn.Linear(hidden_dim, output_dim)
+ self.sigmoid = nn.Sigmoid() if use_sigmoid else nn.Identity()
+
+ def forward( # pylint: disable=missing-function-docstring
+ self,
+ segment_sizes: "torch.Tensor",
+ features: "torch.Tensor",
+ ) -> "torch.Tensor":
+ n_seg = len(segment_sizes)
+ encoded_features = self.encoder(features)
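+ # Map every feature row to its candidate index, then sum the encoded rows per candidate (segment sum)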
+ segment_indices = torch.repeat_interleave(
+ torch.arange(n_seg, device=features.device),
+ segment_sizes.long(),
+ )
+ n_dim = encoded_features.shape[1]
+ segment_sum = torch.scatter_add(
+ input=torch.zeros((n_seg, n_dim), dtype=encoded_features.dtype, device=features.device),
+ dim=0,
+ index=segment_indices.view(-1, 1).expand(-1, n_dim),
+ src=encoded_features,
+ )
+ out = self.norm(segment_sum)
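+ # Two residual MLP blocks, then a linear decoder to a scalar score per candidate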
+ out = self.layer0(out) + out
+ out = self.layer1(out) + out
+ out = self.decoder(out).squeeze()
+ out = self.sigmoid(out)
+ return out
+
+
+def extract_features(
+ context: TuneContext,
+ candidates: List[MeasureCandidate],
+ results: List[RunnerResult] = None,
+ extractor: FeatureExtractor = PerStoreFeature(extract_workload=True),
+):
+ """Extract feature vectors and compute mean costs.
+
+ Parameters
+ ----------
+ context: TuneContext
+ The tuning context.
+ candidates: List[MeasureCandidate]
+ The measure candidates.
+ results: List[RunnerResult]
+ The measured results; can be None when used for prediction.
+ extractor: FeatureExtractor
+ The feature extractor.
+
+ Returns
+ -------
+ new_features: List[np.ndarray]
+ The extracted features.
+ new_mean_costs: np.ndarray
+ The mean costs.
+ """
+
+ def _feature(feature: NDArray) -> np.ndarray:
+ return feature.numpy().astype("float32")
+
+ def _mean_cost(res: RunnerResult) -> float:
+ if not res.run_secs:
+ return 1e10
+ return float(np.median([float(s) for s in res.run_secs]))
+
+ new_features = [_feature(x) for x in extractor.extract_from(context, candidates)]
+ new_mean_costs = (
+ np.array([_mean_cost(x) for x in results]).astype("float32")
+ if results is not None
+ else None
+ )
+ return new_features, new_mean_costs
+
+
+class State:
+ """State of the trainer
+
+ Parameters
+ ----------
+ model: SegmentSumMLP
+ The cost model.
+ data: Dict[str, FeatureGroup]
+ The data groups.
+ data_size: int
+ The size of all data.
+ untrained_size: int
+ The size of the untrained data.
+ """
+
+ model: SegmentSumMLP
+ data: Dict[str, FeatureGroup]
+ data_size: int
+ untrained_size: int
+
+ def __init__(
+ self,
+ model_config: SegmentSumMLPConfig = SegmentSumMLPConfig(),
+ extractor: FeatureExtractor = PerStoreFeature(extract_workload=True),
+ ):
+ self.model = SegmentSumMLP(**model_config.to_dict())
+ self.data = OrderedDict()
+ self.data_size = 0
+ self.untrained_size = 0
+ self.extractor = extractor
+
+ def load( # pylint: disable=too-many-locals
+ self,
+ path: str,
+ target: str = "nvidia/nvidia-v100",
+ ) -> None:
+ """Load the cached model, cached features, or raw data.
+
+ Parameters
+ ----------
+ path: str
+ The path to the tar file containing cached model, cached features,
+ or raw data.
+ target: str
+ The target for the tuning context.
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model_path = os.path.join(tmp_dir, "model.pth")
+ cache_path = os.path.join(tmp_dir, "cached_data.npy")
+ raw_path = os.path.join(tmp_dir, "raw_data")
+ untar(path, tmp_dir)
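+ # Load cached model weights if present; prefer cached features over re-extracting them from raw JSON databases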
+ if os.path.exists(model_path):
+ self.model.load_state_dict(torch.load(model_path))
+ if os.path.exists(cache_path):
+ for group_hash, features, costs in np.load(cache_path, allow_pickle=True):
+ self.data[group_hash] = FeatureGroup(
+ group_hash=group_hash,
+ features=list(features),
+ costs=costs,
+ )
+ self.data_size += len(costs)
+ self.untrained_size += len(costs)
+ elif os.path.exists(raw_path):
+ from tqdm import tqdm # type: ignore
+
+ model_dirs = glob.glob(os.path.join(raw_path, "*"))
+ workload_paths = []
+ for model_dir in model_dirs:
+ json_files = glob.glob(os.path.join(model_dir, "*.json"))
+ for json_file in json_files:
+ if json_file.endswith("_workload.json"):
+ workload_paths.append(json_file)
+ for workload_path in tqdm(workload_paths):
+ try:
+ database = JSONDatabase(
+ path_workload=workload_path,
+ path_tuning_record=workload_path.replace(
+ "_workload.json", "_candidates.json"
+ ),
+ )
+ except tvm._ffi.base.TVMError: # pylint: disable=protected-access
+ continue
+ candidates, results = [], []
+ tuning_records = database.get_all_tuning_records()
+ if len(tuning_records) == 0:
+ continue
+ for record in tuning_records:
+ candidates.append(record.as_measure_candidate())
+ results.append(RunnerResult(run_secs=record.run_secs, error_msg=None))
+ assert len(candidates) == len(results)
+ context = TuneContext(mod=tuning_records[0].workload.mod, target=Target(target))
+ features, mean_costs = extract_features(
+ context, candidates, results, self.extractor
+ )
+ self.add_to_group(features, mean_costs, shash2hex(context.mod))
+
+ def save(self, path: str) -> None:
+ """Cache the model and data.
+
+ Parameters
+ ----------
+ path: str
+ The path to the cached tar file.
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model_path = os.path.join(tmp_dir, "model.pth")
+ cache_path = os.path.join(tmp_dir, "cached_data.npy")
+ torch.save(self.model.state_dict(), model_path)
+ data = [
+ (
+ g.group_hash,
+ g.features,
+ g.costs,
+ )
+ for g in self.data.values()
+ ]
+ np.save(
+ file=cache_path,
+ arr=np.array(data, dtype=object),
+ )
+ tar(path, [x for x in [model_path, cache_path] if x is not None])
+ logger.info("Saved MLPModel to %s", path)
+
+ def add_to_group(
+ self,
+ features: List[np.ndarray],
+ costs: np.ndarray,
+ group_hash: str,
+ ):
+ """Add features and costs to the data groups with key group_hash.
+
+ Parameters
+ ----------
+ features: List[np.ndarray]
+ The feature vectors.
+ costs: np.ndarray
+ The measured results.
+ group_hash: str
+ The structural hash of the candidates.
+ """
+ group = self.data.get(group_hash, None)
+ if group is None:
+ group = FeatureGroup(
+ group_hash=group_hash,
+ features=features,
+ costs=costs,
+ )
+ else:
+ group.append(features, costs)
+ self.data[group_hash] = group
+ self.data_size += len(features)
+ self.untrained_size += len(features)
+
+
+class SegmentSumMLPTrainer:
+ """The trainer for Segment Sum MLP model.
+
+ Parameters
+ ----------
+ state: State
+ The state of the trainer.
+ batch_size : int
+ The batch size.
+ learning_rate : float
+ The learning rate.
+ weight_decay : float
+ The weight decay.
+ num_epoch_full : int
+ The number of epochs used in full training.
+ num_epoch_incremental : int
+ The number of epochs used in incremental training.
+ grad_clip_norm: float
+ The norm of gradient clipping.
+ train_verbose: int
+ The logging frequency during training, in batches.
+ test_interval: int
+ The testing interval in epochs.
+ test_split: float
+ The fraction of data for testing.
+ frozen: bool
+ Whether to freeze the model, i.e. skip re-training when new data arrives.
+ optimizer: "torch.optim.adam.Adam"
+ The optimizer.
+ scheduler: "torch.optim.lr_scheduler.StepLR"
+ The scheduler.
+ """
+
+ state: State
+ batch_size: int = 128
+ learning_rate: float = 7e-4
+ weight_decay: float = 1e-6
+ num_epoch_full: int = 50
+ num_epoch_incremental: int = 5
+ grad_clip_norm: float = 0.5
+ train_verbose: int = 1000
+ test_interval: int = 1
+ test_split: float = 0.2
+ frozen: bool = False
+ optimizer: "torch.optim.adam.Adam" # type: ignore
+ scheduler: "torch.optim.lr_scheduler.StepLR" # type: ignore
+
+ def __init__(
+ self,
+ train_config: TrainerConfig = TrainerConfig(),
+ state: State = State(),
+ ):
+ config = train_config.to_dict()
+ for attr in config:
+ setattr(self, attr, config[attr])
+ self.state = state
+ self.device = "cuda" if torch.cuda.device_count() else "cpu"
+ self.optimizer, self.scheduler = None, None
+
+ def train_step(
+ self,
+ data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"],
+ batch: int = 0,
+ train_loss: float = None,
+ ) -> float:
+ """Helper function for training on a single batch.
+
+ Parameters
+ ----------
+ data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]
+ A batch of data; should be a tuple of (segment_sizes, features, gt_results).
+ batch: int = 0
+ The current batch number.
+ train_loss: float = None
+ The running average of the training loss; None for the first batch.
+
+ Returns
+ -------
+ train_loss: float
+ The averaged training loss after the current batch.
+ """
+ segment_sizes, features, gt_results = (
+ data[0].to(self.device),
+ data[1].to(self.device),
+ data[2].to(self.device),
+ )
+ self.optimizer.zero_grad()
+ pred_results = self.state.model(segment_sizes, features)
+ loss = lambda_rank_loss(pred_results, gt_results)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(self.state.model.parameters(), self.grad_clip_norm)
+ self.optimizer.step()
+ loss = loss.detach().cpu()
+ train_loss = (
+ train_loss * 0.95 + loss.item() * 0.05 if train_loss is not None else loss.item()
+ )
+ segment_sizes, features, gt_results, pred_results = (
+ segment_sizes.detach().cpu(),
+ features.detach().cpu(),
+ gt_results.detach().cpu(),
+ pred_results.detach().cpu(),
+ )
+ if batch % self.train_verbose == 0:
+ logger.info("Batch: %d, train loss: %6f", batch, train_loss)
+ return train_loss
+
+ def predict_step(
+ self,
+ data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"],
+ ):
+ """Helper function for predicting (validating) on a single batch.
+
+ Parameters
+ ----------
+ data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]
+ A batch of data; should be a tuple of (segment_sizes, features, gt_results).
+ gt_results can be None when the batch is used for prediction.
+
+ Returns
+ -------
+ pred_results: np.ndarray
+ The predicted results for the current batch.
+ test_loss_batch: float
+ The test loss for the current batch; None when used for prediction.
+ test_scores_batch: List[float]
+ The top-k scores for the current batch; empty when used for prediction.
+ """
+ test_loss_batch, test_scores_batch = None, []
+ segment_sizes, features = (
+ data[0].to(self.device),
+ data[1].to(self.device),
+ )
+ gt_results = data[2]
+ pred_results = self.state.model(segment_sizes, features)
+ segment_sizes, features, pred_results = (
+ segment_sizes.detach().cpu(),
+ features.detach().cpu(),
+ pred_results.detach().cpu(),
+ )
+ if gt_results is not None:
+ test_loss_batch = lambda_rank_loss(pred_results, gt_results).item()
+ for k in [1, 5, 10]:
+ test_scores_batch.append(topk_score(pred_results, gt_results, k))
+ return pred_results.numpy(), test_loss_batch, test_scores_batch
+
+ def train_full(self): # pylint: disable=too-many-locals
+ """Training on the full dataset."""
+ # split into training and testing set
+ keys = list(self.state.data.keys())
+ test_keys = random.sample(keys, k=math.floor(len(keys) * self.test_split))
+ train_data = OrderedDict()
+ test_data = OrderedDict()
+ for key in keys:
+ if key in test_keys:
+ test_data[key] = self.state.data[key]
+ else:
+ train_data[key] = self.state.data[key]
+ train_features = list(
+ itertools_chain.from_iterable([g.features for g in train_data.values()])
+ )
+ test_features = list(
+ itertools_chain.from_iterable([g.features for g in test_data.values()])
+ )
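+ # Labels are relative throughputs within each workload group: min cost / cost, so the best candidate scores 1.0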
+ train_results = np.concatenate([g.min_cost / g.costs for g in train_data.values()])
+ test_results = np.concatenate([g.min_cost / g.costs for g in test_data.values()])
+ train_loader = SegmentDataLoader(
+ train_features, train_results, batch_size=self.batch_size, shuffle=True
+ )
+ test_loader = SegmentDataLoader(
+ test_features, test_results, batch_size=self.batch_size, shuffle=False
+ )
+ self.optimizer = torch.optim.Adam(
+ self.state.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+ )
+ self.scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer,
+ step_size=self.num_epoch_full // 10,
+ gamma=0.8,
+ verbose=True,
+ )
+ self.state.model = self.state.model.to(self.device)
+ min_test_loss = 1e10
+ logger.info("Training size: %d; Testing size: %d", len(train_loader), len(test_loader))
+
+ model_cache_path = tempfile.NamedTemporaryFile().name # pylint: disable=consider-using-with
+ for epoch in range(self.num_epoch_full):
+ logger.info("Epoch: %d", epoch)
+ # training
+ self.state.model.train()
+ train_loss = None
+ for batch, data in enumerate(train_loader):
+ train_loss = self.train_step(data, batch, train_loss)
+ self.scheduler.step()
+ # testing
+ if epoch % self.test_interval == 0:
+ self.state.model.eval()
+ test_losses, test_scores = [], []
+ for data in test_loader:
+ _, test_loss_batch, test_scores_batch = self.predict_step(data)
+ test_losses.append(test_loss_batch)
+ test_scores.append(test_scores_batch)
+ test_loss = (
+ np.array(test_losses[:-1]).mean() if len(test_losses) > 1 else test_losses[0]
+ )
+ logger.info(
+ "Average test loss: %6f, top1 score: %5f, top5 score: %5f, top10 score: %5f",
+ test_loss,
+ np.array(test_scores)[:, 0].mean(),
+ np.array(test_scores)[:, 1].mean(),
+ np.array(test_scores)[:, 2].mean(),
+ )
+ if test_loss < min_test_loss:
+ min_test_loss = test_loss
+ torch.save(self.state.model.state_dict(), model_cache_path)
+ self.state.model.to("cpu").load_state_dict(torch.load(model_cache_path))
+ self.state.untrained_size = 0
+
+ def train_incremental(
+ self,
+ features: List[np.ndarray],
+ results: np.ndarray,
+ ):
+ """Training on incremental data.
+
+ Parameters
+ ----------
+ features: List[np.ndarray]
+ The extracted features.
+ results: np.ndarray
+ The measured results.
+ """
+ results = np.min(results) / results
+ loader = SegmentDataLoader(features, results, batch_size=self.batch_size, shuffle=True)
+ self.optimizer = torch.optim.Adam(
+ self.state.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+ )
+ self.state.model = self.state.model.to(self.device)
+ logger.info("Incremental training size: %d", len(loader))
+ for epoch in range(self.num_epoch_incremental):
+ logger.info("Epoch: %d", epoch)
+ self.state.model.train()
+ loss = None
+ for batch, data in enumerate(loader):
+ loss = self.train_step(data, batch, loss)
+ self.state.model.to("cpu")
+ self.state.untrained_size = max(0, self.state.untrained_size - len(loader))
+
+ def predict_incremental(
+ self,
+ features: List[np.ndarray],
+ results: np.ndarray = None,
+ ) -> np.ndarray:
+ """Predicting (validating) on incremental data.
+
+ Parameters
+ ----------
+ features: List[np.ndarray]
+ The extracted features.
+ results: np.ndarray
+ The measured results; can be None when used for prediction.
+
+ Returns
+ -------
+ pred_results: np.ndarray
+ The predicted results.
+ """
+ if results is not None:
+ results = np.min(results) / results
+ loader = SegmentDataLoader(features, results, batch_size=self.batch_size, shuffle=False)
+ self.state.model = self.state.model.to(self.device).eval()
+ logger.info("Incremental testing size: %d", len(loader))
+ pred_results, losses, scores = [], [], []
+ for data in loader:
+ pred_results_batch, losses_batch, scores_batch = self.predict_step(data)
+ pred_results.append(pred_results_batch)
+ losses.append(losses_batch)
+ scores.append(scores_batch)
+ pred_results = np.concatenate(pred_results)
+ if results is not None:
+ losses = np.array(losses[:-1]).mean() if len(losses) > 1 else losses[0]
+ logger.info(
+ "Average test loss: %6f, top1 score: %5f, top5 score: %5f, top10 score: %5f",
+ losses,
+ np.array(scores)[:, 0].mean(),
+ np.array(scores)[:, 1].mean(),
+ np.array(scores)[:, 2].mean(),
+ )
+ return pred_results
+
+ def update(
+ self,
+ features: List[np.ndarray],
+ costs: np.ndarray,
+ group_hash: str,
+ ):
+ """Update the dataset and re-train the model if not frozen.
+
+ Parameters
+ ----------
+ features: List[np.ndarray]
+ The extracted features.
+ costs: np.ndarray
+ The measured results.
+ group_hash: str
+ The hash of the group.
+ """
+ self.state.add_to_group(features, costs, group_hash)
+ if not self.frozen:
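+ # First validate on the incoming batch, then re-train: a full pass once over 20% of the data is unseen, otherwise a short incremental pass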
+ self.predict_incremental(features, costs)
+ if self.state.untrained_size / self.state.data_size > 0.2:
+ self.train_full()
+ else:
+ self.train_incremental(features, costs)
+
+
+@derived_object
+class MLPModel(PyCostModel):
+ """Segment Sum MLP Model
+
+ Parameters
+ ----------
+ trainer: SegmentSumMLPTrainer
+ The trainer for the model, handling the training interface.
+ """
+
+ trainer: SegmentSumMLPTrainer
+
+ def __init__(
+ self,
+ *,
+ trainer: SegmentSumMLPTrainer = SegmentSumMLPTrainer(),
+ ):
+ super().__init__()
+ self.trainer = trainer
+
+ def load(self, path: str) -> None:
+ """Load the cost model, cached data or raw data from given file location.
+
+ Parameters
+ ----------
+ path : str
+ The file path.
+ """
+ self.trainer.state.load(path)
+
+ def save(self, path: str) -> None:
+ """Save the cost model and data to given file location.
+
+ Parameters
+ ----------
+ path : str
+ The file path.
+ """
+ self.trainer.state.save(path)
+
+ def update(
+ self,
+ context: TuneContext,
+ candidates: List[MeasureCandidate],
+ results: List[RunnerResult],
+ ) -> None:
+ """Update the dataset, re-train the cost model if not frozen.
+
+ Parameters
+ ----------
+ context : TuneContext,
+ The tuning context.
+ candidates : List[MeasureCandidate]
+ The measure candidates.
+ results : List[RunnerResult]
+ The running results of the measure candidates.
+ """
+ features, mean_costs = extract_features(
+ context, candidates, results, self.trainer.state.extractor
+ )
+ self.trainer.update(features, mean_costs, shash2hex(context.mod))
+
+ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+ """Predict given the measure candidates.
+
+ Parameters
+ ----------
+ context : TuneContext,
+ The tuning context.
+ candidates : List[MeasureCandidate]
+ The measure candidates.
+
+ Returns
+ -------
+ result : np.ndarray
+ The predicted normalized score.
+ """
+ features, _ = extract_features(context, candidates, None, self.trainer.state.extractor)
+ pred_results = self.trainer.predict_incremental(features)
+ return pred_results
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index a55ffa8b28..f8fb64e924 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -203,7 +203,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
} catch (std::runtime_error& e) {
LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
<< " of file " << path_tuning_record << ". The workload is:\n"
- << (workload.defined() ? tir::AsTVMScript(workload) : "(null)")
+ << (workload.defined() ? tir::AsTVMScript(workload->mod) : "(null)")
<< "\nThe JSONObject of TuningRecord is:\n"
<< json_obj << "\nThe error message is:\n"
<< e.what();