Kathryn-cat commented on code in PR #11961: URL: https://github.com/apache/tvm/pull/11961#discussion_r914340967
########## python/tvm/meta_schedule/cost_model/mlp_model.py: ########## @@ -0,0 +1,743 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +MLP-based cost model +""" + +import logging +import math +import os +import random +import tempfile +from collections import OrderedDict +from itertools import chain as itertools_chain +from typing import Dict, List, NamedTuple +from tqdm import tqdm + +import numpy as np +import torch +from torch import nn + +# pylint: disable=relative-beyond-top-level +from ...contrib.tar import tar, untar +from ...runtime import NDArray +from ..cost_model import PyCostModel +from ..feature_extractor import FeatureExtractor, PerStoreFeature +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..tune_context import TuneContext +from ..utils import derived_object, shash2hex + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +# pylint: disable=no-member + + +class SegmentSumMLPConfig(NamedTuple): + """SegmentSum MLP model configuration + + Parameters + ---------- + input_dim : int + The input dim for the model. + hidden_dim : int + The hidden dim for the model. + output_dim : int + The output dim for the model. 
+ use_norm : bool + Whether to normalize the segment sum or not. + use_sigmoid : bool + Whether to use sigmoid on the final output or not. + """ + + input_dim: int = 172 + hidden_dim: int = 256 + output_dim: int = 1 + use_norm: bool = False + use_sigmoid: bool = False + + def to_dict(self): # pylint: disable=missing-function-docstring + return { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "output_dim": self.output_dim, + "use_norm": self.use_norm, + "use_sigmoid": self.use_sigmoid, + } + + +# pylint: disable=too-few-public-methods +class FeatureGroup: + """Feature group + + Parameters + ---------- + group_hash : str + The hash of the group + features : List[np.ndarray] + The features + costs : List[float] + The costs + min_cost : float + The minimum cost + """ + + group_hash: str + features: List[np.ndarray] + costs: np.ndarray + min_cost: float + + def __init__( + self, + group_hash: str, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.group_hash = group_hash + self.features = features + self.costs = costs + self.min_cost = np.min(costs) + + def append( # pylint: disable=missing-function-docstring + self, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.features.extend(features) + self.costs = np.append(self.costs, costs) + self.min_cost = np.min(self.costs) + + +# pylint: disable=too-many-instance-attributes +class SegmentDataLoader: + """Dataloader for SegmentSum MLP model. 
+ + Parameters + ---------- + features : List[np.ndarray] + The features + results : np.ndarray + The measured results + batch_size : int + The batch size + shuffle : bool + Whether to shuffle the dataset or not + """ + + def __init__( + self, + features, + results, + batch_size=128, + shuffle=False, + ): + self.batch_size = batch_size + self.shuffle = shuffle + self.data_size = len(features) + + # flatten features and store the starting indices + self.segment_sizes = torch.tensor([len(feature) for feature in features]) + self.feature_offsets = ( + torch.cumsum(self.segment_sizes, 0, dtype=torch.int32) - self.segment_sizes + ) + features = torch.cat([torch.tensor(feature) for feature in features]) + norm = features.max(dim=0)[0] + norm[norm == 0] = 1 + self.features = features / norm + self.results = torch.tensor(results) + self.iter_order = self.pointer = None + + def __len__(self): + return self.data_size + + def __iter__(self): + if self.shuffle: + self.iter_order = torch.randperm(self.data_size) + else: + self.iter_order = torch.arange(self.data_size) + self.pointer = 0 + return self + + def __next__(self): + if self.pointer >= self.data_size: + raise StopIteration + batch_indices = self.iter_order[self.pointer : self.pointer + self.batch_size] + self.pointer += self.batch_size + return self._fetch_indices(batch_indices) + + def _fetch_indices(self, indices): + segment_sizes, feature_offsets = self.segment_sizes[indices], self.feature_offsets[indices] + feature_indices = torch.empty(segment_sizes.sum(), dtype=torch.int32) + idx = 0 + for offset, seg_size in zip(feature_offsets, segment_sizes): + feature_indices[idx : idx + seg_size] = torch.arange(offset, offset + seg_size) + idx += seg_size + features = self.features[feature_indices.long()] + results = self.results[indices.long()] + return segment_sizes, features, results + + +class SegmentSumMLPModule(nn.Module): + """SegmentSum MLP model. 
+ + Parameters + ---------- + input_dim : int + The input dim for the model. + hidden_dim : int + The hidden dim for the model. + output_dim : int + The output dim for the model. + use_norm : bool + Whether to normalize the segment sum or not. + use_sigmoid : bool + Whether to use sigmoid on the final output or not. + """ + + input_dim: int + hidden_dim: int + output_dim: int + use_norm: bool + use_sigmoid: bool + + def __init__( # pylint: disable=too-many-arguments + self, + input_dim: int = 172, + hidden_dim: int = 256, + output_dim: int = 1, + use_norm: bool = False, + use_sigmoid: bool = False, + ): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.norm = nn.BatchNorm1d(hidden_dim) if use_norm else nn.Identity() + self.layer0 = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.layer1 = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) + self.decoder = nn.Linear(hidden_dim, output_dim) + self.sigmoid = nn.Sigmoid() if use_sigmoid else nn.Identity() + + def forward( + self, + segment_sizes: torch.Tensor, + features: torch.Tensor, + ): + """Forward the inputs with the model. + + Parameters + ---------- + segment_sizes : Tensor + The sizes of the segments. + features : Tensor + The feature vectors of the candidates. 
+ """ + n_seg = len(segment_sizes) + encoded_features = self.encoder(features) + segment_indices = torch.repeat_interleave( + torch.arange(n_seg, device=features.device), + segment_sizes.long(), + ) + n_dim = encoded_features.shape[1] + segment_sum = torch.scatter_add( + input=torch.zeros((n_seg, n_dim), dtype=encoded_features.dtype, device=features.device), + dim=0, + index=segment_indices.view(-1, 1).expand(-1, n_dim), + src=encoded_features, + ) + out = self.norm(segment_sum) + out = self.layer0(out) + out + out = self.layer1(out) + out + out = self.decoder(out).squeeze() + out = self.sigmoid(out) + return out + + # pylint: disable=no-self-use,too-many-arguments,too-many-locals,invalid-name + def lambda_rank_loss(self, preds, labels, k=None, eps=1e-10, sigma=1.0): + """LambdaLoss: Metric-Driven Loss for Learning-to-Rank + + Parameters + ---------- + preds : Tensor + The predicted runtime for each candidate. + labels : Tensor + The measured runtime for each candidate. + k : int + Loss for top k. + Default is None, which means computing all scores. + eps : float + The minimum value to the denominator and argument of log if they reach 0. + sigma : float + The scaling factor to the input of the sigmoid function. 
+ """ + device = preds.device + y_pred, y_true = preds[None, :], labels[None, :] + y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1) + y_true_sorted, _ = y_true.sort(descending=True, dim=-1) + true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred) + true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :] + padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0) + ndcg_at_k_mask = torch.zeros( + (y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device + ) + ndcg_at_k_mask[:k, :k] = 1 + true_sorted_by_preds.clamp_(min=0.0) + y_true_sorted.clamp_(min=0.0) + pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device) + D = torch.log2(1.0 + pos_idxs.float())[None, :] + maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps) + G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] + weights = torch.abs( + torch.pow(D[:, :, None], -1.0) - torch.pow(D[:, None, :], -1.0) + ) * torch.abs(G[:, :, None] - G[:, None, :]) + scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp( + min=-1e8, max=1e8 + ) + scores_diffs[torch.isnan(scores_diffs)] = 0.0 + weighted_probs = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp( + min=eps + ) + losses = torch.log2(weighted_probs) + masked_losses = losses[padded_pairs_mask & ndcg_at_k_mask] + loss = -torch.sum(masked_losses) + return loss + + def topk_score( + self, + pred_results: torch.Tensor, + gt_results: torch.Tensor, + k: int, + ) -> float: + """ + Evaluate the top-k score + + Parameters + ---------- + pred_results: Tensor + The raw prediction + gt_results: Tensor + The measured results + k : int + The k in top k score + + Returns + ------- + score : float + The top-k score + """ + topk_indices = torch.topk(pred_results, k, largest=False).indices + score = gt_results.min() / gt_results[topk_indices].min() + return score.item() + + +@derived_object +class 
MLPModel(PyCostModel): Review Comment: Updated a draft for this. Perhaps we can discuss it a bit? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
